/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/sequence_project_op.h"

namespace paddle {
namespace operators {

class SequenceProjectOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of SequenceProjectOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of SequenceProjectOp should not be null.");
    // PaddingData must not be empty. Otherwise the check fails with
    // "EnforceNotMet: enforce numel() > 0 failed, 0 <= 0".
    PADDLE_ENFORCE(
        ctx->HasInput("PaddingData"),
        "Input(PaddingData) of SequenceProjectOp should not be null.");

    auto in_dims = ctx->GetInputDim("X");
    PADDLE_ENFORCE(in_dims.size() == 2, "Input(X) should be a 2-D tensor.");

    int context_length = ctx->Attrs().Get<int>("context_length");
    bool padding_trainable = ctx->Attrs().Get<bool>("padding_trainable");
    int context_start = ctx->Attrs().Get<int>("context_start");

    if (padding_trainable) {
      framework::DDim padding_dim = ctx->GetInputDim("PaddingData");
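      // up_pad rows of padding are prepended to each sequence when the
      // context window starts before it (context_start < 0); down_pad rows
      // are appended when the window extends past its end.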
      int up_pad = std::max(0, -context_start);
      int down_pad = std::max(0, context_start + context_length - 1);
      int total_pad = up_pad + down_pad;
      int input_width = static_cast<int>(in_dims[1]);

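      // A window with context_start == 0 and context_length == 1 never
      // reaches outside the sequence (up_pad == down_pad == 0), so there is
      // nothing to pad and trainable padding makes no sense.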
      if (context_start == 0 && context_length == 1) {
        PADDLE_THROW(
            "If context_start is 0 and context_length is 1, padding_trainable "
            "should be false.");
      }
      PADDLE_ENFORCE(padding_dim.size() == 2,
                     "Input(PaddingData) should be a 2-D tensor.");
      PADDLE_ENFORCE(
          padding_dim[0] == total_pad && padding_dim[1] == input_width,
          "Input(PaddingData)'s shape is not consistent with 'context_start' "
          "and 'context_length'.");
    }

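    // Each output row concatenates context_length time-steps of input
    // features, so the output width is in_dims[1] * context_length.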
    in_dims[1] = in_dims[1] * context_length;
    ctx->SetOutputDim("Out", in_dims);
  }
};

class SequenceProjectGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Gradient of output(Out) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("X"), "The input(X) should not be null.");

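    // A gradient for PaddingData is produced only when the padding is
    // trainable and the corresponding output variable is given.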
    if (ctx->Attrs().Get<bool>("padding_trainable") &&
        ctx->HasOutput(framework::GradVarName("PaddingData"))) {
      auto padding_dims = ctx->GetInputDim("PaddingData");
      ctx->SetOutputDim(framework::GradVarName("PaddingData"), padding_dims);
    }
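    // The gradient w.r.t. X has the same shape as X.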
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
    }
  }
};

class SequenceProjectOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  SequenceProjectOpMaker(framework::OpProto* proto,
                         framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
             "(A float LoDTensor) the input of SequenceProjectOp, a 2-D "
             "matrix of size (minibatch, number_of_input_features).");
    AddOutput("Out",
              "(A float LoDTensor) the output of SequenceProjectOp, a 2-D "
              "matrix of size (minibatch, number_of_input_features x "
              "context_length).");
    AddInput("PaddingData",
             "(A float LoDTensor) the input of SequenceProjectOp, a 2-D "
             "matrix of size (up_pad + down_pad, number_of_input_features).");

    AddAttr<bool>("padding_trainable",
                  "(bool, default false) whether the padding data of "
                  "SequenceProjectOp is trainable.")
        .SetDefault(false);
    AddAttr<int>("context_length",
                 "(int, default 3) the context_length of SequenceProjectOp.")
        .SetDefault(3)
        .GreaterThan(0);
    AddAttr<int>("context_start",
                 "(int, default 0) the context_start of SequenceProjectOp.")
        .SetDefault(0);
    AddAttr<int>("context_stride",
                 "(int, default 1) the context_stride of SequenceProjectOp. "
                 "Currently, sequence_project_op only supports "
                 "context_stride=1.")
        .SetDefault(1)
        .GreaterThan(0);

    AddComment(R"DOC(
    SequenceProjectOp projects features of context_length time-steps of each instance.

    For a mini-batch of 2 variable-length sentences with 3 and 1 time-steps respectively:

    Assume the input (X) is a [4, M, N] float LoDTensor with X->lod()[0] = [0, 3, 4].
    For simplicity, we further assume M=1 and N=2.

    X = [[a1, a2,
          b1, b2,
          c1, c2],
         [d1, d2]]

    That is, the input (X) has 4 words, and each word is represented by a
    2-dimensional vector.

    - Case1:
    If context_start is -1, context_length is 3, and padding_trainable is false,
    zeros are used for padding instead of learned padding weights, and the
    output (Out) is:

    Out = [0,  0,  a1, a2, b1, b2;
           a1, a2, b1, b2, c1, c2;
           b1, b2, c1, c2, 0,  0;
           0,  0,  d1, d2, 0,  0]

    - Case2:
    If context_start is -1, context_length is 3, and padding_trainable is true,
    learned padding weights are used for padding, and the output (Out) is:

    Out = [w1, w2, a1, a2, b1, b2;
           a1, a2, b1, b2, c1, c2;
           b1, b2, c1, c2, w3, w4;
           w1, w2, d1, d2, w3, w4]

    )DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
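// REGISTER_OP registers the forward operator together with its gradient
// operator under the names sequence_project and sequence_project_grad.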
REGISTER_OP(sequence_project, ops::SequenceProjectOp,
            ops::SequenceProjectOpMaker, sequence_project_grad,
            ops::SequenceProjectGradOp);

REGISTER_OP_CPU_KERNEL(
    sequence_project,
    ops::SequenceProjectKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    sequence_project_grad,
    ops::SequenceProjectGradKernel<paddle::platform::CPUPlace, float>);