/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/fluid/operators/interpolate_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using framework::Tensor;

23
class InterpolateOp : public framework::OperatorWithKernel {
24 25 26 27 28 29
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
30
                   "Input(X) of InterpolateOp should not be null.");
31
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
32 33 34 35 36 37
                   "Output(Out) of InterpolationOp should not be null.");

    auto interp_method = ctx->Attrs().Get<std::string>("interp_method");
    PADDLE_ENFORCE(
        "bilinear" == interp_method || "nearest" == interp_method,
        "Interpolation method can only be \"bilinear\" or \"nearest\".");
38 39 40 41

    auto dim_x = ctx->GetInputDim("X");  // NCHW format
    PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4");

D
dengkaipeng 已提交
42 43 44
    int out_h, out_w;
    float scale = ctx->Attrs().Get<float>("scale");
    if (scale > 0) {
D
dengkaipeng 已提交
45 46 47
      // round down
      out_h = static_cast<int>(dim_x[2] * scale);
      out_w = static_cast<int>(dim_x[3] * scale);
D
dengkaipeng 已提交
48 49 50
      // protect when input shape is -1
      out_h = out_h > 0 ? out_h : -1;
      out_w = out_w > 0 ? out_w : -1;
D
dengkaipeng 已提交
51 52 53
    } else {
      out_h = ctx->Attrs().Get<int>("out_h");
      out_w = ctx->Attrs().Get<int>("out_w");
D
dengkaipeng 已提交
54 55
      PADDLE_ENFORCE_GT(out_h, 0, "out_h should be greater than 0.");
      PADDLE_ENFORCE_GT(out_w, 0, "out_w should be greater than 0.");
D
dengkaipeng 已提交
56 57
    }

58
    if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
59 60 61 62
      auto out_size_dim = ctx->GetInputDim("OutSize");
      PADDLE_ENFORCE_EQ(out_size_dim.size(), 1,
                        "OutSize's dimension size must be 1");
      PADDLE_ENFORCE_EQ(out_size_dim[0], 2, "OutSize's dim[0] must be 2");
63 64
      ctx->ShareLoD("X", "Out");
      return;
65
    }
66

D
dengkaipeng 已提交
67 68
    std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
    ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
69 70 71 72 73
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
Y
Yu Yang 已提交
74 75
    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
                                   ctx.GetPlace());
76 77 78
  }
};

79
class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
80 81 82
 public:
  void Make() override {
    AddInput("X",
83 84
             "The input tensor of interpolate operator, "
             "This is a 4-D tensor with shape of [N,  C, H, w].");
85
    AddInput("OutSize",
86
             "This is a 1-D tensor with two numbers to specify output size. "
87 88
             "The first number is height and the second number is width.")
        .AsDispensable();
89 90 91
    AddOutput("Out",
              "The output tensor of interpolate operator, "
              "This is a 4-D tensor with shape of [N, C, H, W].");
92

93 94
    AddAttr<int>("out_h", "output height of interpolate op.");
    AddAttr<int>("out_w", "output width of interpolate op.");
D
dengkaipeng 已提交
95
    AddAttr<float>("scale", "scale factor of interpolate op.").SetDefault(0.);
96 97 98 99 100 101
    AddAttr<std::string>("interp_method",
                         "(string, default \"bilinear\"), interpolation "
                         "method, can be \"bilinear\" for "
                         "bilinear interpolation and \"nearest\" for nearest "
                         "neighbor interpolation.")
        .SetDefault("bilinear");
102 103
    AddAttr<bool>(
        "align_corners",
T
Tink_Y 已提交
104
        "an optional bool. Defaults to True. "
105 106
        "If True, the centers of 4 corner pixels of the input and output "
        "tensors are aligned, preserving the values at the corner pixels, "
T
Tink_Y 已提交
107
        "If False, are not aligned")
108 109
        .SetDefault(true);
    AddAttr<int>("align_mode",
T
Tink_Y 已提交
110
                 "(int, default \'1\'), optional for bilinear interpolation, "
T
tink2123 已提交
111 112
                 "can be \'0\' for src_idx = scale*(dst_indx+0.5)-0.5 , "
                 "can be \'1\' for src_idx = scale*dst_index .")
T
tink2123 已提交
113
        .SetDefault(1);
114
    AddComment(R"DOC(
115 116 117 118 119
          This operator samples input X to given output shape by using specified
          interpolation method, the interpolation methods can be \"nearest\"
          for nearest neighbor interpolation and \"bilinear\" for bilinear 
          interpolation.

120
          Nearest neighbor interpolation is to perform nearest neighbor interpolation
121
          in both the 3rd dimention(in height direction) and the 4th dimention(in width 
122 123
          direction) on input tensor.
            
124 125 126 127 128 129
          Bilinear interpolation is an extension of linear interpolation for 
          interpolating functions of two variables (e.g. H-direction and 
          W-direction in this op) on a rectilinear 2D grid. The key idea is 
          to perform linear interpolation first in one direction, and then 
          again in the other direction.

T
tink2123 已提交
130
          Align_corners and align_mode are optinal parameters,the calculation method 
131 132 133 134
          of interpolation can be selected by them.
          
          Example:

T
tink2123 已提交
135
          For scale:
136 137 138 139 140 141 142 143 144 145 146 147
          
            if align_corners = True and out_{size}>1 :

              scale_{factor} = (in_{size}-1.0)/(out_{size}-1.0)
            
            else:
              
              scale_{factor} = float(in_{size}/out_{size})
            
          
          Nearest neighbor interpolation:
          
T
tink2123 已提交
148
          if:
149 150 151 152 153 154 155 156
              align_corners = False

              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:

              H_out = \left \lfloor {H_{in} * scale_{}factor}} \right \rfloor
              W_out = \left \lfloor {W_{in} * scale_{}factor}} \right \rfloor

T
tink2123 已提交
157
          else:
158 159 160 161 162 163 164 165 166 167
              align_corners = True

              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:

              H_out = round(H_{in} * scale_{factor})
              W_out = round(W_{in} * scale_{factor})

          Bilinear interpolation:

T
tink2123 已提交
168
          if:
169 170 171 172 173 174 175 176 177
              align_corners = False , align_mode = 0
              
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              
              H_out = (H_{in}+0.5) * scale_{factor} - 0.5
              W_out = (W_{in}+0.5) * scale_{factor} - 0.5


T
tink2123 已提交
178
          else:
179 180 181 182 183 184 185 186 187
           
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:

              H_out = H_{in} * scale_{factor}
              W_out = W_{in} * scale_{factor}

          

188
          For details of nearest neighbor interpolation, please refer to Wikipedia: 
189
          https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation
190 191 192

          For details of bilinear interpolation, please refer to Wikipedia: 
          https://en.wikipedia.org/wiki/Bilinear_interpolation
193 194 195 196
         )DOC");
  }
};

197
class InterpolateOpGrad : public framework::OperatorWithKernel {
198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) should not be null");
    auto dim_x = ctx->GetInputDim("X");
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
S
sneaxiy 已提交
214 215 216
    return framework::OpKernelType(
        ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
        ctx.GetPlace());
217 218 219
  }
};

S
sneaxiy 已提交
220 221 222 223 224 225 226 227 228
// Builds the description of the "<forward type>_grad" op from a forward
// interpolate op.
class InterpolateGradDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  // Wires X, the optional OutSize and Out@GRAD in, X@GRAD out, and copies the
  // forward op's attributes unchanged.
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> grad_op(new framework::OpDesc());
    grad_op->SetType(ForwardOp().Type() + "_grad");

    grad_op->SetInput("X", Input("X"));

    // OutSize is forwarded only when the forward op actually has it.
    const bool forward_has_out_size =
        ForwardOp().Inputs().count("OutSize") > 0;
    if (forward_has_out_size) {
      grad_op->SetInput("OutSize", Input("OutSize"));
    }

    grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    grad_op->SetAttrMap(Attrs());
    return grad_op;
  }
};

// Declares InterpolateGradNoNeedBufferVarsInference, naming input "X"; it is
// attached to both *_grad operators at registration.
// In this file the grad op only reads X's dims (InferShape) and takes its
// kernel data type from Out@GRAD, which is consistent with X's data buffer
// not being needed.
// NOTE(review): presumably the grad kernels also use only X's shape — confirm
// against interpolate_op.h before relying on this.
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(InterpolateGradNoNeedBufferVarsInference,
                                      "X");

242 243 244 245
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Operator registration. Both interpolation flavours share the same op, maker
// and grad-desc-maker classes; their grad ops share InterpolateOpGrad and the
// no-need-buffer inference for X.
REGISTER_OPERATOR(bilinear_interp,
                  ops::InterpolateOp,
                  ops::InterpolateOpMaker,
                  ops::InterpolateGradDescMaker);
REGISTER_OPERATOR(bilinear_interp_grad,
                  ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInference);
REGISTER_OPERATOR(nearest_interp,
                  ops::InterpolateOp,
                  ops::InterpolateOpMaker,
                  ops::InterpolateGradDescMaker);
REGISTER_OPERATOR(nearest_interp_grad,
                  ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInference);

// CPU kernel registration. Forward kernels cover float, double and uint8_t;
// gradient kernels cover float and double only.
REGISTER_OP_CPU_KERNEL(bilinear_interp,
                       ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>,
                       ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(bilinear_interp_grad,
                       ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);
REGISTER_OP_CPU_KERNEL(nearest_interp,
                       ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>,
                       ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(nearest_interp_grad,
                       ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);