pad_op.cc 4.8 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
W
wanghaoshuang 已提交
14

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/pad_op.h"
S
sneaxiy 已提交
16
#include <memory>
W
wanghaoshuang 已提交
17 18 19 20 21 22 23 24 25 26

namespace paddle {
namespace operators {

using framework::Tensor;

class PadOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

27
  void InferShape(framework::InferShapeContext* ctx) const override {
Q
Qiao Longfei 已提交
28 29 30 31 32
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of PadOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of PadOp should not be null.");

    auto x_dim = ctx->GetInputDim("X");
S
sneaxiy 已提交
33
    auto& paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
W
wanghaoshuang 已提交
34
    PADDLE_ENFORCE_EQ(x_dim.size() * 2, int64_t(paddings.size()),
W
wanghaoshuang 已提交
35 36
                      "Size of paddings should be equal to 2 * dimension size "
                      "of input tensor.");
S
SunGaofeng 已提交
37 38 39
    for (size_t i = 0; i < paddings.size(); ++i) {
      PADDLE_ENFORCE_GE(paddings[i], 0, "paddings should >= 0.");
    }
W
wanghaoshuang 已提交
40
    std::vector<int64_t> out_dims(x_dim.size());
W
wanghaoshuang 已提交
41
    for (int i = 0; i < x_dim.size(); ++i) {
42 43 44 45 46
      if ((!ctx->IsRuntime()) && (x_dim[i] == -1)) {
        out_dims[i] = -1;
      } else {
        out_dims[i] = x_dim[i] + paddings[i * 2] + paddings[i * 2 + 1];
      }
W
wanghaoshuang 已提交
47
    }
Q
Qiao Longfei 已提交
48
    ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
D
Fix bug  
dangqingqing 已提交
49 50 51
    if (out_dims[0] == x_dim[0]) {
      // Only pass LoD when the first dimension is equal between
      // output and input.
Q
Qiao Longfei 已提交
52
      ctx->ShareLoD("X", /*->*/ "Out");
D
Fix bug  
dangqingqing 已提交
53
    }
W
wanghaoshuang 已提交
54 55 56
  }
};

W
wanghaoshuang 已提交
57
class PadOpMaker : public framework::OpProtoAndCheckerMaker {
W
wanghaoshuang 已提交
58
 public:
Y
Yu Yang 已提交
59
  void Make() override {
W
wanghaoshuang 已提交
60 61 62 63
    AddInput("X",
             "The input of pad op. "
             "The input should be a k-D tensor(k > 0 and k < 7)");
    AddOutput("Out",
K
kexinzhao 已提交
64
              "The output of pad op. "
65
              "A tensor with the same shape as X.");
K
kexinzhao 已提交
66 67 68 69 70 71 72 73 74 75 76 77
    AddAttr<std::vector<int>>(
        "paddings",
        "(vector<int>) "
        "A list<int> to describe the padding rules for each dimension. "
        "For 2-D image tensor, paddings=[0, 1, 2, 3] means "
        "padding 0 row to top, 1 row to bottom, 2 columns to left "
        "and 3 columns to right. Size of paddings should be equal to "
        "2 * dimension size of the input tensor.");
    AddAttr<float>("pad_value",
                   "(float, default 0.0) "
                   "The value to fill the padded areas.")
        .SetDefault(0.0f);
W
wanghaoshuang 已提交
78
    AddComment(R"DOC(
K
kexinzhao 已提交
79 80 81 82
Pad Operator.

Pad input into output, as specified by paddings and pad_value. 
The input should be a k-D tensor(k > 0 and k < 7). As an example:
W
wanghaoshuang 已提交
83 84 85 86

Given:

X = [[1, 2],
K
kexinzhao 已提交
87
     [3, 4]],
W
wanghaoshuang 已提交
88

K
kexinzhao 已提交
89
paddings = [0, 1, 1, 2],
W
wanghaoshuang 已提交
90 91 92

and

K
kexinzhao 已提交
93
pad_value = 0,
Q
Qiao Longfei 已提交
94

K
kexinzhao 已提交
95
we have:
W
wanghaoshuang 已提交
96 97 98 99

Out = [[0, 1, 2, 0, 0]
       [0, 3, 4, 0, 0]
       [0, 0, 0, 0, 0]]
K
kexinzhao 已提交
100

W
wanghaoshuang 已提交
101 102 103 104 105 106 107 108
)DOC");
  }
};

class PadOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

109
  void InferShape(framework::InferShapeContext* ctx) const override {
Q
Qiao Longfei 已提交
110 111
    auto x_grad_name = framework::GradVarName("X");
    if (ctx->HasOutput(x_grad_name)) {
S
sneaxiy 已提交
112 113 114
      auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
      auto& paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
      for (int i = 0; i < dout_dims.size(); ++i) {
115 116 117
        if (ctx->IsRuntime() || (dout_dims[i] != -1)) {
          dout_dims[i] -= (paddings[i * 2] + paddings[i * 2 + 1]);
        }
S
sneaxiy 已提交
118 119
      }
      ctx->SetOutputDim(x_grad_name, dout_dims);
W
wanghaoshuang 已提交
120
    }
W
wanghaoshuang 已提交
121 122 123
  }
};

124 125 126
class PadOpGradMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
Y
Yu Yang 已提交
127 128

 protected:
Y
Yu Yang 已提交
129 130
  std::unique_ptr<framework::OpDesc> Apply() const override {
    auto* bind = new framework::OpDesc();
Y
Yu Yang 已提交
131 132 133
    bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    bind->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    bind->SetAttrMap(Attrs());
Y
Yu Yang 已提交
134
    bind->SetType("pad_grad");
Y
Yu Yang 已提交
135
    return std::unique_ptr<framework::OpDesc>(bind);
Y
Yu Yang 已提交
136
  }
137 138
};

W
wanghaoshuang 已提交
139 140 141 142
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
143 144 145

REGISTER_OPERATOR(pad, ops::PadOp, ops::PadOpMaker, ops::PadOpGradMaker);
REGISTER_OPERATOR(pad_grad, ops::PadOpGrad);
Q
QI JUN 已提交
146 147 148 149
REGISTER_OP_CPU_KERNEL(
    pad, ops::PadKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
    pad_grad, ops::PadGradKernel<paddle::platform::CPUDeviceContext, float>);