/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/sequence_ops/sequence_expand_op.h"
#include <memory>

namespace paddle {
namespace operators {

using framework::LoDTensor;

class SequenceExpandOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SequenceExpand");
    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "SequenceExpand");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SequenceExpand");

    auto x_dims = ctx->GetInputDim("X");
    auto out_dims = x_dims;
    int ref_level = ctx->Attrs().Get<int>("ref_level");

    PADDLE_ENFORCE_GE(
        x_dims.size(), 2,
        platform::errors::InvalidArgument(
            "Dimension number of Input(X) should be at least 2. But "
            "received: input rank %u, input shape [%s].",
            x_dims.size(), x_dims));

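    // The LoD of X and Y is only available at runtime, so the first
    // dimension of Out can be computed exactly only then; at compile
    // time it is left as -1 (see the else branch below).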
    if (ctx->IsRuntime()) {
      framework::Variable* x_var =
          boost::get<framework::Variable*>(ctx->GetInputVarPtrs("X")[0]);
      framework::Variable* y_var =
          boost::get<framework::Variable*>(ctx->GetInputVarPtrs("Y")[0]);

      auto& x_lod = x_var->Get<LoDTensor>().lod();
      auto& y_lod = y_var->Get<LoDTensor>().lod();

      PADDLE_ENFORCE_LE(x_lod.size(), 1UL,
                        platform::errors::InvalidArgument(
                            "Level of Input(X)'s lod should not be "
                            "greater than 1. But received: lod level %u.",
                            x_lod.size()));
      PADDLE_ENFORCE_GT(
          y_lod.size(), 0UL,
          platform::errors::InvalidArgument(
              "Level of Input(Y)'s lod should be greater than 0. But "
              "received: lod level %u.",
              y_lod.size()));
      PADDLE_ENFORCE_EQ(
          ref_level == -1 ||
              (ref_level >= 0 && ref_level < static_cast<int>(y_lod.size())),
          true, platform::errors::InvalidArgument(
                    "Invalid `ref_level`, which should be either equal to -1 "
                    "or in [0, %d), but received `ref_level` = %u.",
                    y_lod.size(), ref_level));

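      // A ref_level of -1 means the last lod level of Input(Y) is used.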
      if (ref_level == -1) ref_level = y_lod.size() - 1;

      if (x_lod.size() > 0) {
        PADDLE_ENFORCE_EQ(
            x_lod[0].size(), y_lod[ref_level].size(),
            platform::errors::InvalidArgument(
                "Level number of Input(X)'s lod could be 0. Otherwise "
                "size of Input(X)'s first level lod should be equal to "
                "size of Input(Y)'s referred level lod. But received: "
                "Input(X).lod[0].size() = %u, Input(Y).lod[%d].size() = "
                "%u",
                x_lod[0].size(), ref_level, y_lod[ref_level].size()));
      } else {
        PADDLE_ENFORCE_EQ(
            x_dims[0], static_cast<int64_t>(y_lod[ref_level].size()) - 1,
            platform::errors::InvalidArgument(
                "When Input(X)'s lod is null, the dims[0] of "
                "Input(X) should match the "
                "size of Input(Y)'s referred level lod. But received "
                "Input(X): input rank %u, input shape [%s]; received "
                "Input(Y).lod[%d].size() - 1 = %d.",
                x_dims.size(), x_dims, ref_level,
                static_cast<int64_t>(y_lod[ref_level].size()) - 1));
      }

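      // Accumulate the expanded first dimension: segment i of Y's
      // ref_level lod repeats the i-th sequence (or row) of X
      // (y_lod[ref_level][i] - y_lod[ref_level][i - 1]) times.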
      int64_t out_first_dim = 0;
      if (y_lod[ref_level].size() <= 1) {
        out_first_dim = x_dims[0];
      } else {
        for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
          int x_seq_len = 1;
          if (x_lod.size() == 1) {
            x_seq_len = x_lod[0][i] - x_lod[0][i - 1];
          }
          out_first_dim +=
              (y_lod[ref_level][i] - y_lod[ref_level][i - 1]) * x_seq_len;
        }
      }
      out_dims[0] = out_first_dim;
    } else {
      out_dims[0] = -1;
    }
    ctx->SetOutputDim("Out", out_dims);
    ctx->ShareLoD("X", /*->*/ "Out");
  }

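  // The kernel data type follows the data type of Input(X).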
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
  }
};

class SequenceExpandOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor whose lod "
             "level is at most 1.");
    AddInput("Y",
             "(LoDTensor, default LoDTensor<float>) Referred LoDTensor whose "
             "lod (specified level) is referred by Input(X).");
    AddOutput("Out",
              "(LodTensor, default LoDTensor<float>) Output LoDTensor which is "
              "generated from Input(X) by referring lod of Input(Y).");
    AddAttr<int>("ref_level", "Specify lod level of Input(Y).").SetDefault(-1);
    AddComment(R"DOC(
Sequence Expand Operator.

This operator expands `X` according to the specified lod level of `Y`. The current
implementation requires that the lod level of `X` be at most 1. The attribute
`ref_level` specifies which lod level of `Y` is referred to when expanding `X`.
If `ref_level` is set to -1, the last lod level of `Y` is used.
Please note that the rank of `X` should be at least 2; when the rank exceeds 2,
`X` is viewed as a 2-D tensor.

The following cases illustrate how this works:

Case 1:

Given a 1-level LoDTensor input(X)
    X.lod =  [[0,   2,        4]]
    X.data = [[a], [b], [c], [d]]
    X.dims = [4, 1]
and input(Y)
    Y.lod = [[0,    2,    4],
             [0, 3, 6, 7, 8]]
ref_level: 0
then we get 1-level LoDTensor
    Out.lod =  [[0,   2,        4,        6,        8]]
    Out.data = [[a], [b], [a], [b], [c], [d], [c], [d]]
    Out.dims = [8, 1]

Case 2:

Given 1-level LoDTensor input(X)
    X.lod =  [[0,   1,        4]]
    X.data = [[a], [b], [c], [d]]
    X.dims = [4, 1]
and input(Y)
    Y.lod = [[0,    2,    4],
             [0, 3, 6, 6, 8]]
ref_level: 0
then we get 1-level LoDTensor
    Out.lod =  [[0,   1,   2,        5,             8]]
    Out.data = [[a], [a], [b], [c], [d], [b], [c], [d]]
    Out.dims = [8, 1]

Case 3:

Given a common Tensor input(X)
    X.data = [[a], [b], [c]]
    X.dims = [3, 1]
and input(Y)
    Y.lod = [[0, 2, 3, 6]]
ref_level: -1
then we get a common Tensor
    Out.data = [[a], [a], [b], [c], [c], [c]]
    Out.dims = [6, 1]

Case 4:

Given a common Tensor input(X)
    X.data = [[a, b], [c, d], [e, f]]
    X.dims = [3, 2]
and input(Y)
    Y.lod = [[0, 2, 3, 6]]
ref_level: 0
then we get a common LoDTensor
    Out.data = [[a, b], [a, b], [c, d], [e, f], [e, f], [e, f]]
    Out.dims = [6, 2]

)DOC");
  }
};

class SequenceExpandOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SequenceExpandOpGrad");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                   framework::GradVarName("Out"), "SequenceExpandOpGrad");

    auto x_dims = ctx->GetInputDim("X");
    auto x_grad_name = framework::GradVarName("X");

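    // The gradient w.r.t. X has the same shape as X.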
    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, x_dims);
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.GetPlace());
  }
};

template <typename T>
class SequenceExpandOpGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
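  // Describes the generated sequence_expand_grad op: it consumes X, Y and
  // Out@GRAD, produces X@GRAD, and reuses the forward op's attributes.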
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("sequence_expand_grad");
    op->SetInput("X", this->Input("X"));
    op->SetInput("Y", this->Input("Y"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
  }
};

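// The forward kernel only needs the LoD of Y (not its data), and the backward
// kernel does not read the data buffers of X or Y, so they are declared as
// no-need-buffer variables.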
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequenceExpandOpNoNeedBufferVarsInference,
                                    "Y");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(
    SequenceExpandGradOpNoNeedBufferVarsInference, "X", "Y");

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(sequence_expand, ops::SequenceExpandOp,
                  ops::SequenceExpandOpMaker,
                  ops::SequenceExpandOpGradMaker<paddle::framework::OpDesc>,
                  ops::SequenceExpandOpGradMaker<paddle::imperative::OpBase>,
                  ops::SequenceExpandOpNoNeedBufferVarsInference);
REGISTER_OPERATOR(sequence_expand_grad, ops::SequenceExpandOpGrad,
                  ops::SequenceExpandGradOpNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(
    sequence_expand,
    ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, float>,
    ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, double>,
    ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, int>,
    ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
    sequence_expand_grad,
    ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, double>,
    ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, int>,
    ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, int64_t>);