sequence_slice_op.cc 6.4 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

W
Wu Yi 已提交
15
#include "paddle/fluid/operators/sequence_ops/sequence_slice_op.h"
16
#include <memory>
17 18 19 20

namespace paddle {
namespace operators {

21
class SequenceSliceOp : public framework::OperatorWithKernel {
22 23 24 25
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
26 27 28 29
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SequenceSlice");
    OP_INOUT_CHECK(ctx->HasInput("Offset"), "Input", "Offset", "SequenceSlice");
    OP_INOUT_CHECK(ctx->HasInput("Length"), "Input", "Length", "SequenceSlice");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SequenceSlice");
30 31
    auto input_dims = ctx->GetInputDim("X");

32 33 34
    auto offset_dim = ctx->GetInputDim("Offset");
    auto length_dim = ctx->GetInputDim("Length");

W
wanghaox 已提交
35 36
    PADDLE_ENFORCE_EQ(
        offset_dim.size(), 2UL,
37 38 39 40 41
        platform::errors::InvalidArgument(
            "Input Offset dimension error. SequenceSlice operator only support "
            "one level sequence now, the dimension of input Offset must be 2, "
            "but received dimension is %d.",
            offset_dim.size()));
W
wanghaox 已提交
42 43
    PADDLE_ENFORCE_EQ(
        length_dim.size(), 2UL,
44 45 46 47 48
        platform::errors::InvalidArgument(
            "Input Length dimension error. SequenceSlice operator only support "
            "one level sequence now, the dimension of input Length must be 2, "
            "but received dimension is %d.",
            offset_dim.size()));
49

W
wanghaox 已提交
50 51
    // Initialize the output's dims to maximum,
    // and re-set to real dims by the value of Offset and Length at kernel
52
    ctx->SetOutputDim("Out", input_dims);
53
  }
54

55
 protected:
56
  framework::OpKernelType GetExpectedKernelType(
57
      const framework::ExecutionContext& ctx) const override {
58 59 60
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
        ctx.device_context());
61 62 63
  }
};

64
class SequenceSliceGradOp : public framework::OperatorWithKernel {
65 66 67 68
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
69 70 71 72
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                   framework::GradVarName("Out"), "SequenceSliceGrad");
    OP_INOUT_CHECK(ctx->HasOutputs(framework::GradVarName("X")), "Output",
                   framework::GradVarName("X"), "SequenceSliceGrad");
73 74
    ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
  }
75 76

 protected:
77
  framework::OpKernelType GetExpectedKernelType(
78
      const framework::ExecutionContext& ctx) const override {
79 80 81
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.device_context());
82
  }
83 84
};

85
class SequenceSliceOpMaker : public framework::OpProtoAndCheckerMaker {
86
 public:
Y
Yu Yang 已提交
87
  void Make() override {
88 89 90 91 92
    AddInput("X",
             "(LoDTensor), "
             "the input of SequenceSliceOp.");
    AddInput("Offset",
             "(Tensor), "
93 94
             "a vector<int> to describe the offset of every input sequence for "
             "sub sequence item.");
95 96
    AddInput("Length",
             "(Tensor), "
97 98
             "a vector<int> to describe the length of every input sequence for "
             "sub sequence item.");
99
    AddOutput("Out", "(LoDTensor), the output of SequenceSliceOp.");
100
    AddComment(R"DOC(
101
Sequence slice operator
102

W
wanghaox 已提交
103
The operator crops a subsequence from given sequence with given start offset and subsequence length.
104 105
It only supports sequence (LoD Tensor with level number is 1).
- Case:
106 107 108 109 110
    X = [[a1, a2;
        b1, b2;
        c1, c2]
       [d1, d2;
        e1, e2]]
111
    LoD(X) = {{0, 3, 5}}; Dims(X) = (5, 2)
112
    Offset = [[0], [1]]; Length = [[2], [1]]
113 114 115 116

    Out = [[a1, a2;
            b1, b2]
            [e1, e2]]
117
    LoD(Out) = {{0, 2, 3}}; Dims(Out) = (3, 2)
W
wanghaox 已提交
118
NOTE: The first dimension size of input, the size of offset and Length, should be equal. The offset start from 0.
119 120 121 122
    )DOC");
  }
};

H
hong 已提交
123 124
// Describes the sequence_slice_grad op. It receives the forward inputs
// (X, Offset, Length) plus Out@GRAD, emits X@GRAD, and carries over all
// attributes of the forward op.
template <typename T>
class SequenceSliceGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("sequence_slice_grad");

    // Forward inputs are passed through under identical slot names.
    for (const char* slot : {"X", "Offset", "Length"}) {
      op->SetInput(slot, this->Input(slot));
    }

    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
  }
};

140
// Declares that sequence_slice_grad does not need the data buffer of input X.
// Its InferShape reads only X's dims.
DECLARE_NO_NEED_BUFFER_VARS_INFERER(
    SequenceSliceGradNoNeedBufferVarsInferer, "X");
142

143 144 145 146
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
Y
Yang Yang 已提交
147
REGISTER_OPERATOR(sequence_slice, ops::SequenceSliceOp,
H
hong 已提交
148 149 150
                  ops::SequenceSliceOpMaker,
                  ops::SequenceSliceGradOpMaker<paddle::framework::OpDesc>,
                  ops::SequenceSliceGradOpMaker<paddle::imperative::OpBase>);
151
REGISTER_OPERATOR(sequence_slice_grad, ops::SequenceSliceGradOp,
152
                  ops::SequenceSliceGradNoNeedBufferVarsInferer);
153
REGISTER_OP_CPU_KERNEL(
154
    sequence_slice,
155 156 157 158
    ops::SequenceSliceOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::SequenceSliceOpKernel<paddle::platform::CPUDeviceContext, double>,
    ops::SequenceSliceOpKernel<paddle::platform::CPUDeviceContext, int>,
    ops::SequenceSliceOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
159
REGISTER_OP_CPU_KERNEL(
160
    sequence_slice_grad,
161 162 163 164 165
    ops::SequenceSliceGradOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::SequenceSliceGradOpKernel<paddle::platform::CPUDeviceContext, double>,
    ops::SequenceSliceGradOpKernel<paddle::platform::CPUDeviceContext, int>,
    ops::SequenceSliceGradOpKernel<paddle::platform::CPUDeviceContext,
                                   int64_t>);