pad_op.cc 5.6 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
W
wanghaoshuang 已提交
2

L
Luo Tao 已提交
3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
W
wanghaoshuang 已提交
6

L
Luo Tao 已提交
7
    http://www.apache.org/licenses/LICENSE-2.0
W
wanghaoshuang 已提交
8

L
Luo Tao 已提交
9 10 11 12 13
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
W
wanghaoshuang 已提交
14

S
sneaxiy 已提交
15
#include <memory>
16
#include "paddle/fluid/framework/op_registry.h"
17
#include "paddle/fluid/platform/complex.h"
W
wanghaoshuang 已提交
18 19 20 21 22 23 24 25 26 27

namespace paddle {
namespace operators {

using framework::Tensor;

class PadOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

28
  void InferShape(framework::InferShapeContext* ctx) const override {
29 30
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Pad");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Pad");
Q
Qiao Longfei 已提交
31 32

    auto x_dim = ctx->GetInputDim("X");
S
sneaxiy 已提交
33
    auto& paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
34 35 36 37 38 39 40
    PADDLE_ENFORCE_EQ(
        static_cast<int>(paddings.size()), x_dim.size() * 2,
        platform::errors::InvalidArgument(
            "Size of 'paddings' dimension should be equal to 2 * size of "
            "Input(X)'s dimension, but received (size of 'paddings' dimension "
            "is) %d vs (2 * size of Input(X)'s dimension is) %d.",
            static_cast<int>(paddings.size()), x_dim.size() * 2));
S
SunGaofeng 已提交
41
    for (size_t i = 0; i < paddings.size(); ++i) {
42 43 44 45 46
      PADDLE_ENFORCE_GE(paddings[i], 0,
                        platform::errors::InvalidArgument(
                            "The element of 'paddings' should >= 0, but "
                            "received %d for index %d.",
                            paddings[i], static_cast<int>(i)));
S
SunGaofeng 已提交
47
    }
W
wanghaoshuang 已提交
48
    std::vector<int64_t> out_dims(x_dim.size());
W
wanghaoshuang 已提交
49
    for (int i = 0; i < x_dim.size(); ++i) {
50 51 52 53 54
      if ((!ctx->IsRuntime()) && (x_dim[i] == -1)) {
        out_dims[i] = -1;
      } else {
        out_dims[i] = x_dim[i] + paddings[i * 2] + paddings[i * 2 + 1];
      }
W
wanghaoshuang 已提交
55
    }
56
    ctx->SetOutputDim("Out", phi::make_ddim(out_dims));
D
Fix bug  
dangqingqing 已提交
57 58 59
    if (out_dims[0] == x_dim[0]) {
      // Only pass LoD when the first dimension is equal between
      // output and input.
Q
Qiao Longfei 已提交
60
      ctx->ShareLoD("X", /*->*/ "Out");
D
Fix bug  
dangqingqing 已提交
61
    }
W
wanghaoshuang 已提交
62 63 64
  }
};

W
wanghaoshuang 已提交
65
class PadOpMaker : public framework::OpProtoAndCheckerMaker {
W
wanghaoshuang 已提交
66
 public:
Y
Yu Yang 已提交
67
  void Make() override {
W
wanghaoshuang 已提交
68 69 70 71
    AddInput("X",
             "The input of pad op. "
             "The input should be a k-D tensor(k > 0 and k < 7)");
    AddOutput("Out",
K
kexinzhao 已提交
72
              "The output of pad op. "
73
              "A tensor with the same shape as X.");
K
kexinzhao 已提交
74 75 76 77 78 79 80 81 82 83 84 85
    AddAttr<std::vector<int>>(
        "paddings",
        "(vector<int>) "
        "A list<int> to describe the padding rules for each dimension. "
        "For 2-D image tensor, paddings=[0, 1, 2, 3] means "
        "padding 0 row to top, 1 row to bottom, 2 columns to left "
        "and 3 columns to right. Size of paddings should be equal to "
        "2 * dimension size of the input tensor.");
    AddAttr<float>("pad_value",
                   "(float, default 0.0) "
                   "The value to fill the padded areas.")
        .SetDefault(0.0f);
W
wanghaoshuang 已提交
86
    AddComment(R"DOC(
K
kexinzhao 已提交
87 88 89 90
Pad Operator.

Pad input into output, as specified by paddings and pad_value. 
The input should be a k-D tensor(k > 0 and k < 7). As an example:
W
wanghaoshuang 已提交
91 92 93 94

Given:

X = [[1, 2],
K
kexinzhao 已提交
95
     [3, 4]],
W
wanghaoshuang 已提交
96

K
kexinzhao 已提交
97
paddings = [0, 1, 1, 2],
W
wanghaoshuang 已提交
98 99 100

and

K
kexinzhao 已提交
101
pad_value = 0,
Q
Qiao Longfei 已提交
102

K
kexinzhao 已提交
103
we have:
W
wanghaoshuang 已提交
104 105 106 107

Out = [[0, 1, 2, 0, 0]
       [0, 3, 4, 0, 0]
       [0, 0, 0, 0, 0]]
K
kexinzhao 已提交
108

W
wanghaoshuang 已提交
109 110 111 112 113 114 115 116
)DOC");
  }
};

class PadOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

117
  void InferShape(framework::InferShapeContext* ctx) const override {
Q
Qiao Longfei 已提交
118 119
    auto x_grad_name = framework::GradVarName("X");
    if (ctx->HasOutput(x_grad_name)) {
S
sneaxiy 已提交
120 121 122
      auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
      auto& paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
      for (int i = 0; i < dout_dims.size(); ++i) {
123 124 125
        if (ctx->IsRuntime() || (dout_dims[i] != -1)) {
          dout_dims[i] -= (paddings[i * 2] + paddings[i * 2 + 1]);
        }
S
sneaxiy 已提交
126 127
      }
      ctx->SetOutputDim(x_grad_name, dout_dims);
W
wanghaoshuang 已提交
128
    }
W
wanghaoshuang 已提交
129 130 131
  }
};

H
hong 已提交
132 133
template <typename T>
class PadOpGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  // Emits a `pad_grad` op that consumes Out@GRAD and produces X@GRAD, carrying
  // over every attribute of the forward `pad` op (paddings, pad_value, ...).
  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("pad_grad");
    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    grad_op->SetAttrMap(this->Attrs());
  }
};

C
ceci3 已提交
146 147 148 149 150 151 152 153 154 155 156 157 158
template <typename T>
class PadOpDoubleGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  // Builds the gradient op of `pad_grad`. pad_grad shrinks Out@GRAD back to
  // X's shape (its InferShape subtracts the paddings); since `pad` is linear,
  // that is the adjoint "crop". The gradient of a crop is a zero-padding, so
  // this emits a `pad` op taking X@GRAD@GRAD as X and producing
  // Out@GRAD@GRAD, with the same `paddings`.
  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("pad");
    grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
    grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
    grad_op->SetAttrMap(this->Attrs());
    // The attribute map above is the forward op's, including its `pad_value`.
    // The cropped-away border receives no gradient, so it must be filled with
    // zeros regardless of the value used in the forward pass.
    grad_op->SetAttr("pad_value", 0.0f);
  }
};

W
wanghaoshuang 已提交
159 160 161 162
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
163

H
hong 已提交
164 165 166
// Forward `pad` op. Its grad makers (one for static-graph OpDesc, one for
// imperative OpBase) both emit a `pad_grad` op.
REGISTER_OPERATOR(pad, ops::PadOp, ops::PadOpMaker,
                  ops::PadOpGradMaker<paddle::framework::OpDesc>,
                  ops::PadOpGradMaker<paddle::imperative::OpBase>);
C
ceci3 已提交
167 168 169
// Backward `pad_grad` op. Its own grad makers emit a `pad` op, which provides
// the double-gradient of `pad`.
REGISTER_OPERATOR(pad_grad, ops::PadOpGrad,
                  ops::PadOpDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::PadOpDoubleGradMaker<paddle::imperative::OpBase>);