/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/conv_shift_op.h"
#include "paddle/fluid/framework/eigen.h"

namespace paddle {
namespace operators {

// Bring the framework tensor type into this namespace.
using framework::Tensor;

// Local alias that forwards to framework::EigenMatrix, defaulting to
// Eigen::RowMajor layout and Eigen::DenseIndex indices.
template <typename Scalar, int Layout = Eigen::RowMajor,
          typename IndexT = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<Scalar, Layout, IndexT>;

// Forward operator definition for conv_shift: validates the shapes of the
// inputs X and Y and declares the shape of the output Out.
class ConvShiftOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Checks that X, Y and Out are present, that X and Y are rank-2 with the
  // same 1st dimension, and that Y's 2nd dimension is odd and no larger than
  // X's. Out then receives X's dimensions and shares X's LoD.
  //
  // A comparison is skipped when one of the extents it reads is negative.
  // NOTE(review): this assumes the framework reports a dimension that is not
  // yet known during compile-time shape inference as a negative value
  // (conventionally -1) -- confirm. Without the guard such a placeholder
  // always fails the oddness check, because -1 % 2 == -1 rather than 1.
  // Non-negative extents are validated exactly as before.
  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should be not null.");

    auto x_dims = ctx->GetInputDim("X");
    auto y_dims = ctx->GetInputDim("Y");
    PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
    PADDLE_ENFORCE_EQ(y_dims.size(), 2, "Input(Y)'s rank should be 2.");

    const bool batch_known = x_dims[0] >= 0 && y_dims[0] >= 0;
    const bool x_width_known = x_dims[1] >= 0;
    const bool y_width_known = y_dims[1] >= 0;

    if (batch_known) {
      PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0],
                        "The 1st dimension of Input(X) and Input(Y) should "
                        "be equal.");
    }
    if (y_width_known) {
      PADDLE_ENFORCE_EQ(y_dims[1] % 2, 1,
                        "The 2nd dimension of Input(Y) should be odd.");
    }
    if (x_width_known && y_width_known) {
      PADDLE_ENFORCE_LE(y_dims[1], x_dims[1],
                        "The 2nd dimension of Input(Y) should be less than or "
                        "equal to the 2nd dimension of Input(X).");
    }

    ctx->SetOutputDim("Out", x_dims);
    ctx->ShareLoD("X", /*->*/ "Out");
  }
};

// Gradient operator definition for conv_shift.
class ConvShiftGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Requires X, Y and the gradient of Out as inputs. Each gradient output
  // that is present is given the dimensions of the input it belongs to.
  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) should be not null.");

    // Copies the dimensions of `input` onto its gradient variable, but only
    // when that gradient is one of this op's outputs.
    auto mirror_dims_to_grad = [ctx](const char *input) {
      const auto grad_name = framework::GradVarName(input);
      if (ctx->HasOutput(grad_name)) {
        ctx->SetOutputDim(grad_name, ctx->GetInputDim(input));
      }
    };

    mirror_dims_to_grad("X");
    mirror_dims_to_grad("Y");
  }
};

// Declares the inputs, the output and the user-facing documentation of the
// conv_shift operator.
class ConvShiftOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Registers X (B x M), Y (B x N, N odd) and Out (B x M) together with the
  // operator description.
  void Make() override {
    AddInput("X",
             "(Tensor, default Tensor<float>), a 2-D tensor with shape B x M, "
             "where B is the batch size and M is the data dimension.");
    AddInput("Y",
             "(Tensor, default Tensor<float>), a 2-D tensor with shape B x N, "
             "where B is the batch size and N is the data dimension. N must "
             "be odd.");
    AddOutput("Out",
              "(Tensor, default Tensor<float>), a 2-D tensor with shape B x M, "
              "i.e., the same shape as X.");
    AddComment(R"DOC(
ConvShift Operator.

A layer for circular convolution of two vectors,
as used in the Neural Turing Machine: https://arxiv.org/abs/1410.5401

The equation is:

$$Out[i] = \sum_{j=-(N-1)/2}^{(N-1)/2} X_{i+j} * Y_{j}$$

where X's index is computed modulo M, and Y's index is computed modulo N.

Both inputs X and Y can carry LoD (Level of Details) information.
However, the output only shares the LoD information with input X.

)DOC");
  }
};

// CPU forward kernel of conv_shift. For every row k and column i it computes
//   Out(k, i) = sum over j of X(k, (i + j - (N - 1) / 2) mod M) * Y(k, j)
// where M is the width of X and N is the width of Y.
template <typename T>
class ConvShiftKernel<platform::CPUPlace, T> : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *X = context.Input<Tensor>("X");
    auto *Y = context.Input<Tensor>("Y");
    auto *Out = context.Output<Tensor>("Out");
    Out->mutable_data<T>(context.GetPlace());

    auto x = EigenMatrix<T>::From(*X);
    auto y = EigenMatrix<T>::From(*Y);
    auto out = EigenMatrix<T>::From(*Out);
    // The loops below accumulate into Out, so it has to start from zero.
    out.setZero();

    const size_t batch_size = X->dims()[0];
    const size_t x_width = X->dims()[1];
    const size_t y_width = Y->dims()[1];
    const size_t y_half_width = (y_width - 1) / 2;

    for (size_t k = 0; k < batch_size; ++k) {
      for (size_t i = 0; i < x_width; ++i) {
        for (size_t j = 0; j < y_width; ++j) {
          // Column of X paired with Y(k, j), wrapped around modulo x_width.
          // Adding x_width keeps the sum non-negative while
          // y_half_width <= x_width (ConvShiftOp::InferShape enforces
          // y_width <= x_width); any intermediate unsigned wrap-around
          // cancels out. The index stays size_t so that it is not narrowed
          // to int for very wide rows.
          const size_t index = (i + j - y_half_width + x_width) % x_width;
          out(k, i) += x(k, index) * y(k, j);
        }
      }
    }
  }
};

// CPU backward kernel of conv_shift. Using the same wrapped column index as
// the forward pass, it accumulates
//   dX(k, index) += dOut(k, i) * Y(k, j)
//   dY(k, j)     += X(k, index) * dOut(k, i)
// for whichever of dX / dY is requested.
template <typename T>
class ConvShiftGradKernel<platform::CPUPlace, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *X = context.Input<Tensor>("X");
    auto *Y = context.Input<Tensor>("Y");
    auto *dOut = context.Input<Tensor>(framework::GradVarName("Out"));
    auto *dX = context.Output<Tensor>(framework::GradVarName("X"));
    auto *dY = context.Output<Tensor>(framework::GradVarName("Y"));

    auto x = EigenMatrix<T>::From(*X);
    auto y = EigenMatrix<T>::From(*Y);
    auto dout = EigenMatrix<T>::From(*dOut);

    auto x_dims = X->dims();
    auto y_dims = Y->dims();
    const size_t batch_size = x_dims[0];
    const size_t x_width = x_dims[1];
    const size_t y_width = y_dims[1];
    const size_t y_half_width = (y_width - 1) / 2;

    // The wrapped index is kept as size_t in both branches so that it is not
    // narrowed to int for very wide rows. Adding x_width before the modulo
    // keeps the sum non-negative while y_half_width <= x_width.

    // The below trades code duplication for efficiency (keeping the if
    // statement outside of the loop).
    if (dX) {
      dX->mutable_data<T>(context.GetPlace());
      auto dx = EigenMatrix<T>::From(*dX);
      // Gradients are accumulated below, so start from zero.
      dx.setZero();
      for (size_t k = 0; k < batch_size; ++k) {
        for (size_t i = 0; i < x_width; ++i) {
          for (size_t j = 0; j < y_width; ++j) {
            const size_t index = (i + j - y_half_width + x_width) % x_width;
            dx(k, index) += dout(k, i) * y(k, j);
          }
        }
      }
    }

    if (dY) {
      dY->mutable_data<T>(context.GetPlace());
      auto dy = EigenMatrix<T>::From(*dY);
      // Gradients are accumulated below, so start from zero.
      dy.setZero();
      for (size_t k = 0; k < batch_size; ++k) {
        for (size_t i = 0; i < x_width; ++i) {
          for (size_t j = 0; j < y_width; ++j) {
            const size_t index = (i + j - y_half_width + x_width) % x_width;
            dy(k, j) += x(k, index) * dout(k, i);
          }
        }
      }
    }
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Register the forward operator (with its proto maker and the default
// gradient-op description maker) and the gradient operator.
REGISTER_OPERATOR(conv_shift, ops::ConvShiftOp, ops::ConvShiftOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(conv_shift_grad, ops::ConvShiftGradOp);

// Register the float CPU kernels for the forward and gradient operators.
REGISTER_OP_CPU_KERNEL(conv_shift,
                       ops::ConvShiftKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    conv_shift_grad,
    ops::ConvShiftGradKernel<paddle::platform::CPUPlace, float>);