sequence_pad_op.cc 11.6 KB
Newer Older
Y
yangyaming 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

W
Wu Yi 已提交
15
#include "paddle/fluid/operators/sequence_ops/sequence_pad_op.h"
16

17 18
#include <memory>
#include <string>
Y
yangyaming 已提交
19 20 21 22 23 24 25 26

namespace paddle {
namespace operators {

class SequencePadOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

27
 protected:
Y
yangyaming 已提交
28
  void InferShape(framework::InferShapeContext* ctx) const override {
29 30
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"),
                      true,
31 32 33
                      platform::errors::NotFound(
                          "Input(X) of SequencePadOp should not be null."));
    PADDLE_ENFORCE_EQ(
34 35
        ctx->HasInput("PadValue"),
        true,
36 37
        platform::errors::NotFound(
            "Input(PadValue) of SequencePadOp should not be null."));
38 39
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"),
                      true,
40 41 42
                      platform::errors::NotFound(
                          "Output(Out) of SequencePadOp should not be null."));
    PADDLE_ENFORCE_EQ(
43 44
        ctx->HasOutput("Length"),
        true,
45 46
        platform::errors::NotFound(
            "Output(Length) of SequencePadOp should not be null."));
Y
yangyaming 已提交
47 48

    auto x_dims = ctx->GetInputDim("X");
49 50
    PADDLE_ENFORCE_GE(x_dims.size(),
                      2,
51 52 53 54
                      platform::errors::InvalidArgument(
                          "The rank of SequencePadOp Input(X) can't be less "
                          "than 2. But the rank we received is %d",
                          x_dims.size()));
55
    auto time_step_dims = phi::slice_ddim(x_dims, 1, x_dims.size());
56
    auto pad_value_dims = ctx->GetInputDim("PadValue");
57
    PADDLE_ENFORCE_EQ(
58
        pad_value_dims == phi::make_ddim({1}) ||
59
            pad_value_dims == phi::make_ddim({}) ||
60 61 62 63 64
            pad_value_dims == time_step_dims,
        true,
        platform::errors::InvalidArgument(
            "The SequencePadOp Input(PadValue) must be a scalar or a tensor "
            "whose shape equals to time steps in sequences"));
Y
yangyaming 已提交
65

F
fengjiayi 已提交
66
    int out_dim_0 = -1;
Y
yangyaming 已提交
67

68
    int padded_length = ctx->Attrs().Get<int>("padded_length");
Y
yangyaming 已提交
69
    if (ctx->IsRuntime()) {
70
      // run time
Y
yangyaming 已提交
71
      framework::Variable* x_var =
R
Ruibiao Chen 已提交
72
          PADDLE_GET(framework::Variable*, ctx->GetInputVarPtrs("X")[0]);
柠檬味~ 已提交
73
      const auto& x_lod = x_var->Get<phi::DenseTensor>().lod();
74 75
      PADDLE_ENFORCE_EQ(x_lod.empty(),
                        false,
76 77
                        platform::errors::NotFound(
                            "The SequencePadOp Input(X) must hold lod info."));
78
      const auto& x_lod_0 = x_lod[0];
79
      PADDLE_ENFORCE_GE(
80 81
          x_lod_0.size(),
          2,
82 83 84 85
          platform::errors::InvalidArgument(
              "The size of SequencePadOp Input(X)'s lod info can't be less "
              "than 2. But the size we received is %d",
              x_lod_0.size()));
86 87
      PADDLE_ENFORCE_EQ(x_dims[0],
                        static_cast<int64_t>(x_lod_0.back()),
88 89 90 91 92
                        platform::errors::InvalidArgument(
                            "The SequencePadOp Input(X)'s lod info mismatches "
                            "the actual tensor shape. The 1st dimension of "
                            "Input(X)'s lod info is %d, the 1st dimension of "
                            "actual tensor shape is %d",
93 94
                            x_dims[0],
                            static_cast<int64_t>(x_lod_0.back())));
95 96 97 98 99

      int seq_num = x_lod_0.size() - 1;
      int max_seq_len = math::MaximumSequenceLength(x_lod_0);
      if (padded_length == -1) {
        padded_length = max_seq_len;
Y
yangyaming 已提交
100
      }
101
      PADDLE_ENFORCE_GE(
102 103
          padded_length,
          max_seq_len,
104 105 106 107 108 109
          platform::errors::InvalidArgument(
              "The SequencePadOp Attr(padded_length) should be greater than or "
              "equal to the "
              "length of the longest original sequence. But the padded_length "
              "we received is %d, the length of the longest original sequence "
              "is %d",
110 111
              padded_length,
              max_seq_len));
F
fengjiayi 已提交
112
      out_dim_0 = seq_num;
Y
yangyaming 已提交
113
    } else {
114
      // compile time
115 116 117
      if (padded_length == -1) {
        padded_length = 1;
      }
118
      PADDLE_ENFORCE_GT(
119 120
          ctx->GetLoDLevel("X"),
          0,
121 122 123 124
          platform::errors::InvalidArgument(
              "The LoD level of SequencePadOp Input(X) should be "
              "larger than 0. But the LoD level we received is %d",
              ctx->GetLoDLevel("X")));
Y
yangyaming 已提交
125 126
    }

127
    std::vector<int> out_dims_vec{out_dim_0, padded_length};
128
    std::vector<int> len_dims_vec{out_dim_0};
129
    auto time_step_dims_vec = phi::vectorize<int>(time_step_dims);
130 131
    out_dims_vec.insert(out_dims_vec.end(),
                        time_step_dims_vec.begin(),
F
fengjiayi 已提交
132
                        time_step_dims_vec.end());
133 134
    ctx->SetOutputDim("Out", phi::make_ddim(out_dims_vec));
    ctx->SetOutputDim("Length", phi::make_ddim(len_dims_vec));
135 136 137
  }

 protected:
138
  phi::KernelKey GetExpectedKernelType(
139
      const framework::ExecutionContext& ctx) const override {
140
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
141
    return phi::KernelKey(data_type, ctx.GetPlace());
Y
yangyaming 已提交
142 143 144 145 146
  }
};

// Declares the inputs, outputs, attribute and documentation of sequence_pad.
class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(phi::DenseTensor, default phi::DenseTensor<float>) Input "
             "variable which "
             "should contain lod information.");
    AddInput("PadValue",
             "(phi::DenseTensor), this phi::DenseTensor holds values that will "
             "be filled into "
             "padded steps. It can be a scalar or a tensor whose shape equals "
             "to time steps in sequences. If it's a scalar, it will be "
             "automatically broadcasted to the shape of time step.");
    AddOutput("Out",
              "(phi::DenseTensor) The output variable, which contains padded "
              "sequences.");
    AddOutput("Length",
              "(phi::DenseTensor) The output variable, which contains the "
              "actual length of "
              "sequences before padding.");
    // The description matches the run-time check in SequencePadOp::InferShape,
    // which accepts padded_length == longest sequence length.
    AddAttr<int>(
        "padded_length",
        "The length of padded sequences. It can be set to -1 or "
        "any positive int. When it is -1, all sequences will be padded up to "
        "the length of the longest one among them; when it is a certain "
        "positive value, it must be greater than or equal to the length of "
        "the longest original sequence.")
        .SetDefault(-1);
    AddComment(R"DOC(
      Sequence Pad Operator

      This operator pads sequences in the same batch to a consistent length.
      The length is specified by attribute 'padded_length'. New elements,
      whose values are specified by input 'PadValue', will be appended to
      the end of each sequence, to make their final lengths consistent.

      Following are cases to better explain how this works:

      Case 1:

      Given a 1-level phi::DenseTensor input(X):
          X.lod = [[0, 2,       5]]
          X.data = [a, b, c, d, e]
      and Input(PadValue):
          PadValue.data = [0]
      and attribute 'padded_length' = 4,
      then we get phi::DenseTensor:
          Out.data = [[a, b, 0, 0],
                      [c, d, e, 0]]
          Length.data = [2, 3]

      Case 2:

      Given a 1-level phi::DenseTensor input(X):
          X.lod = [[0,               2,                           5]]
          X.data = [[a1, a2], [b1, b2], [c1, c2], [d1, d2], [e1, e2]]
      and Input(PadValue):
          PadValue.data = [0]
      and attribute 'padded_length' = -1, which means using the length
      of longest input sequence(3 in this case),
      then we get phi::DenseTensor:
          Out.data = [[[a1, a2], [b1, b2], [0, 0]],
                      [[c1, c2], [d1, d2], [e1, e2]]]
          Length.data = [2, 3]

      Case 3:

      Given a 1-level phi::DenseTensor input(X):
          X.lod = [[0,               2,                           5]]
          X.data = [[a1, a2], [b1, b2], [c1, c2], [d1, d2], [e1, e2]]
      and Input(PadValue):
          PadValue.data = [p1, p2]
      and attribute 'padded_length' = -1, which means using the length
      of longest input sequence(3 in this case),
      then we get phi::DenseTensor:
          Out.data = [[[a1, a2], [b1, b2], [p1, p2]],
                      [[c1, c2], [d1, d2], [e1, e2]]]
          Length.data = [2, 3]

    )DOC");
  }
};

class SequencePadGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
233 234
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"),
                      true,
235 236
                      platform::errors::NotFound(
                          "Input(X) of SequencePadGradOp should not be null."));
237
    PADDLE_ENFORCE_EQ(
238 239
        ctx->HasInput(framework::GradVarName("Out")),
        true,
240 241
        platform::errors::NotFound(
            "Input(Out@GRAD) of SequencePadGradOp should not be null."));
Y
yangyaming 已提交
242 243 244 245 246 247

    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
      ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
    }
  }
248 249

 protected:
250
  phi::KernelKey GetExpectedKernelType(
251
      const framework::ExecutionContext& ctx) const override {
252 253
    auto data_type = OperatorWithKernel::IndicateVarDataType(
        ctx, framework::GradVarName("Out"));
254
    return phi::KernelKey(data_type, ctx.GetPlace());
255
  }
Y
yangyaming 已提交
256 257
};

H
hong 已提交
258 259
template <typename T>
class SequencePadGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  // Describes the sequence_pad_grad op.
  // Inputs: the forward input X and Out@GRAD.
  // Output: X@GRAD.
  // All attributes of the forward op are copied over.
  void Apply(GradOpPtr<T> grad_op) const override {
    const auto out_grad_name = framework::GradVarName("Out");
    const auto x_grad_name = framework::GradVarName("X");

    grad_op->SetType("sequence_pad_grad");
    grad_op->SetAttrMap(this->Attrs());
    grad_op->SetInput("X", this->Input("X"));
    grad_op->SetInput(out_grad_name, this->OutputGrad("Out"));
    grad_op->SetOutput(x_grad_name, this->InputGrad("X"));
  }
};

// Marks input "X" as a no-need-buffer variable. This inferer is attached to
// sequence_pad_grad at registration.
DECLARE_NO_NEED_BUFFER_VARS_INFERER(
    SequencePadGradOpNoNeedBufferVarsInferer, "X");

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Forward op, with grad makers for both paddle::framework::OpDesc and
// paddle::imperative::OpBase.
REGISTER_OPERATOR(
    sequence_pad,
    ops::SequencePadOp,
    ops::SequencePadOpMaker,
    ops::SequencePadGradOpMaker<paddle::framework::OpDesc>,
    ops::SequencePadGradOpMaker<paddle::imperative::OpBase>);

// Backward op; the inferer declares X as a no-need-buffer input.
REGISTER_OPERATOR(
    sequence_pad_grad,
    ops::SequencePadGradOp,
    ops::SequencePadGradOpNoNeedBufferVarsInferer);

// CPU kernels for float, double, int and int64_t.
REGISTER_OP_CPU_KERNEL(
    sequence_pad,
    ops::SequencePadOpKernel<phi::CPUContext, float>,
    ops::SequencePadOpKernel<phi::CPUContext, double>,
    ops::SequencePadOpKernel<phi::CPUContext, int>,
    ops::SequencePadOpKernel<phi::CPUContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
    sequence_pad_grad,
    ops::SequencePadGradOpKernel<phi::CPUContext, float>,
    ops::SequencePadGradOpKernel<phi::CPUContext, double>,
    ops::SequencePadGradOpKernel<phi::CPUContext, int>,
    ops::SequencePadGradOpKernel<phi::CPUContext, int64_t>);