interpolate_op.cc 14.0 KB
Newer Older
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

12
#include "paddle/fluid/operators/interpolate_op.h"
S
sneaxiy 已提交
13
#include <memory>
14
#include <string>
15 16 17 18 19 20 21 22
#include <vector>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using framework::Tensor;

K
Kaipeng Deng 已提交
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
// Shape inference for a 4-D Input(X) of shape [N, C, H, W], used by the
// "bilinear" and "nearest" interpolation methods.
//
// The output spatial size comes from, in priority order:
//   1. Input(OutSize) at runtime: only OutSize's own shape is validated here
//      (it must be a 1-D tensor of 2 elements) and Out's dims are not set;
//      LoD is shared from X to Out.
//   2. Attr(scale) > 0: out_h = floor(H * scale), out_w = floor(W * scale),
//      with non-positive results reported as -1 (unknown input size).
//   3. Attr(out_h) / Attr(out_w). They must be positive unless Input(OutSize)
//      is present, in which case OutSize decides the size at runtime and
//      non-positive attributes are reported as -1 at compile time.
// The resulting Out dims are [N, C, out_h, out_w].
static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) {
  auto dim_x = ctx->GetInputDim("X");
  auto interp_method = ctx->Attrs().Get<std::string>("interp_method");

  PADDLE_ENFORCE(
      "bilinear" == interp_method || "nearest" == interp_method,
      "Interpolation method can only be \"bilinear\" or \"nearest\" when "
      "Input(X) dimension is 4");

  const bool has_out_size = ctx->HasInput("OutSize");
  if (has_out_size && ctx->IsRuntime()) {
    // At runtime the size is taken from the OutSize tensor, so the
    // scale/out_h/out_w attributes are irrelevant and must not be validated
    // (they may legitimately still hold their default of 0).
    auto out_size_dim = ctx->GetInputDim("OutSize");
    PADDLE_ENFORCE_EQ(out_size_dim.size(), 1,
                      "OutSize's dimension size must be 1");
    PADDLE_ENFORCE_EQ(out_size_dim[0], 2, "OutSize's dim[0] must be 2");
    ctx->ShareLoD("X", "Out");
    return;
  }

  int out_h, out_w;
  float scale = ctx->Attrs().Get<float>("scale");
  if (scale > 0) {
    // round down
    out_h = static_cast<int>(dim_x[2] * scale);
    out_w = static_cast<int>(dim_x[3] * scale);
    // protect when input shape is -1
    out_h = out_h > 0 ? out_h : -1;
    out_w = out_w > 0 ? out_w : -1;
  } else {
    out_h = ctx->Attrs().Get<int>("out_h");
    out_w = ctx->Attrs().Get<int>("out_w");
    if (has_out_size) {
      // Compile time with OutSize: the real size is only known at runtime,
      // so unset attributes just mean "unknown".
      out_h = out_h > 0 ? out_h : -1;
      out_w = out_w > 0 ? out_w : -1;
    } else {
      PADDLE_ENFORCE_GT(out_h, 0, "out_h should be greater than 0.");
      PADDLE_ENFORCE_GT(out_w, 0, "out_w should be greater than 0.");
    }
  }

  std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
  ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
}

// Shape inference for a 5-D Input(X) of shape [N, C, D, H, W], used by the
// "trilinear" interpolation method.
//
// The output spatial size comes from, in priority order:
//   1. Input(OutSize) at runtime: only OutSize's own shape is validated here
//      (it must be a 1-D tensor of 3 elements) and Out's dims are not set;
//      LoD is shared from X to Out.
//   2. Attr(scale) > 0: each of out_d/out_h/out_w = floor(in_size * scale),
//      with non-positive results reported as -1 (unknown input size).
//   3. Attr(out_d) / Attr(out_h) / Attr(out_w). They must be positive unless
//      Input(OutSize) is present, in which case OutSize decides the size at
//      runtime and non-positive attributes are reported as -1 at compile time.
// The resulting Out dims are [N, C, out_d, out_h, out_w].
static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) {
  auto dim_x = ctx->GetInputDim("X");
  auto interp_method = ctx->Attrs().Get<std::string>("interp_method");

  PADDLE_ENFORCE("trilinear" == interp_method,
                 "Interpolation method can only be \"trilinear\" when Input(X) "
                 "dimension is 5");

  const bool has_out_size = ctx->HasInput("OutSize");
  if (has_out_size && ctx->IsRuntime()) {
    // At runtime the size is taken from the OutSize tensor, so the
    // scale/out_d/out_h/out_w attributes are irrelevant and must not be
    // validated (they may legitimately still hold their default of 0).
    auto out_size_dim = ctx->GetInputDim("OutSize");
    PADDLE_ENFORCE_EQ(out_size_dim.size(), 1,
                      "OutSize's dimension size must be 1");
    PADDLE_ENFORCE_EQ(out_size_dim[0], 3, "OutSize's dim[0] must be 3");
    ctx->ShareLoD("X", "Out");
    return;
  }

  int out_d, out_h, out_w;
  float scale = ctx->Attrs().Get<float>("scale");
  if (scale > 0) {
    // round down
    out_d = static_cast<int>(dim_x[2] * scale);
    out_h = static_cast<int>(dim_x[3] * scale);
    out_w = static_cast<int>(dim_x[4] * scale);
    // protect when input shape is -1
    out_d = out_d > 0 ? out_d : -1;
    out_h = out_h > 0 ? out_h : -1;
    out_w = out_w > 0 ? out_w : -1;
  } else {
    out_d = ctx->Attrs().Get<int>("out_d");
    out_h = ctx->Attrs().Get<int>("out_h");
    out_w = ctx->Attrs().Get<int>("out_w");
    if (has_out_size) {
      // Compile time with OutSize: the real size is only known at runtime,
      // so unset attributes just mean "unknown".
      out_d = out_d > 0 ? out_d : -1;
      out_h = out_h > 0 ? out_h : -1;
      out_w = out_w > 0 ? out_w : -1;
    } else {
      PADDLE_ENFORCE_GT(out_d, 0, "out_d should be greater than 0.");
      PADDLE_ENFORCE_GT(out_h, 0, "out_h should be greater than 0.");
      PADDLE_ENFORCE_GT(out_w, 0, "out_w should be greater than 0.");
    }
  }

  std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_d, out_h, out_w});
  ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
}

102
class InterpolateOp : public framework::OperatorWithKernel {
103 104 105 106 107 108
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
109
                   "Input(X) of InterpolateOp should not be null.");
110
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
111 112
                   "Output(Out) of InterpolationOp should not be null.");

113
    auto dim_x = ctx->GetInputDim("X");  // NCHW format
K
Kaipeng Deng 已提交
114 115 116 117 118 119 120 121 122
    PADDLE_ENFORCE(dim_x.size() == 4 || dim_x.size() == 5,
                   "Input(X) dimension must be 4 or 5");

    if (dim_x.size() == 4) {
      // shape check for 2D interpolate for input tensor shape NCHW
      Interpolate2DInferShapeCheck(ctx);
    } else {  // dim_x.size() == 5
      // shape check for 3D interpolate for input tensor shape NCDHW
      Interpolate3DInferShapeCheck(ctx);
123 124 125 126 127 128
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
Y
Yu Yang 已提交
129 130
    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
                                   ctx.GetPlace());
131 132 133
  }
};

134
class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
135 136 137
 public:
  void Make() override {
    AddInput("X",
138
             "The input tensor of interpolate operator, "
K
Kaipeng Deng 已提交
139 140
             "This is a 4-D tensor with shape of [N, C, H, W] or a "
             "5-D tensor with shape of [N, C, D, H, W].");
141
    AddInput("OutSize",
142
             "This is a 1-D tensor with two numbers to specify output size. "
K
Kaipeng Deng 已提交
143 144 145
             "It should be [output_height, output_width] when input is a 4-D "
             "tensor and should be [output_depth, output_height, output_width] "
             "when input is a 5-D tensor.")
146
        .AsDispensable();
147 148
    AddOutput("Out",
              "The output tensor of interpolate operator, "
K
Kaipeng Deng 已提交
149
              "This is a tensor in same rank with Input(X).");
150

K
Kaipeng Deng 已提交
151 152 153
    AddAttr<int>("out_d", "output depth of interpolate op.").SetDefault(0);
    AddAttr<int>("out_h", "output height of interpolate op.").SetDefault(0);
    AddAttr<int>("out_w", "output width of interpolate op.").SetDefault(0);
D
dengkaipeng 已提交
154
    AddAttr<float>("scale", "scale factor of interpolate op.").SetDefault(0.);
155 156 157
    AddAttr<std::string>("interp_method",
                         "(string, default \"bilinear\"), interpolation "
                         "method, can be \"bilinear\" for "
K
Kaipeng Deng 已提交
158 159
                         "bilinear interpolation, \"trilinear\" for trilinear "
                         "interpolation and \"nearest\" for nearest "
160 161
                         "neighbor interpolation.")
        .SetDefault("bilinear");
162 163
    AddAttr<bool>(
        "align_corners",
T
Tink_Y 已提交
164
        "an optional bool. Defaults to True. "
165 166
        "If True, the centers of 4 corner pixels of the input and output "
        "tensors are aligned, preserving the values at the corner pixels, "
T
Tink_Y 已提交
167
        "If False, are not aligned")
168 169
        .SetDefault(true);
    AddAttr<int>("align_mode",
T
Tink_Y 已提交
170
                 "(int, default \'1\'), optional for bilinear interpolation, "
T
tink2123 已提交
171 172
                 "can be \'0\' for src_idx = scale*(dst_indx+0.5)-0.5 , "
                 "can be \'1\' for src_idx = scale*dst_index .")
T
tink2123 已提交
173
        .SetDefault(1);
174
    AddComment(R"DOC(
175 176 177 178 179
          This operator samples input X to given output shape by using specified
          interpolation method, the interpolation methods can be \"nearest\"
          for nearest neighbor interpolation and \"bilinear\" for bilinear 
          interpolation.

180
          Nearest neighbor interpolation is to perform nearest neighbor interpolation
181
          in both the 3rd dimention(in height direction) and the 4th dimention(in width 
182 183
          direction) on input tensor.
            
184 185 186 187 188 189
          Bilinear interpolation is an extension of linear interpolation for 
          interpolating functions of two variables (e.g. H-direction and 
          W-direction in this op) on a rectilinear 2D grid. The key idea is 
          to perform linear interpolation first in one direction, and then 
          again in the other direction.

K
Kaipeng Deng 已提交
190 191 192 193 194
          Trilinear interpolation is an extension of linear interpolation for 
          interpolating functions of three variables (e.g. D-direction, 
          H-direction and W-direction in this op) on a rectilinear 3D grid. 
          The linear interpolation is performed on three directions.

T
tink2123 已提交
195
          Align_corners and align_mode are optinal parameters,the calculation method 
196 197 198 199
          of interpolation can be selected by them.
          
          Example:

T
tink2123 已提交
200
          For scale:
201 202 203 204 205 206 207 208 209 210 211 212
          
            if align_corners = True and out_{size}>1 :

              scale_{factor} = (in_{size}-1.0)/(out_{size}-1.0)
            
            else:
              
              scale_{factor} = float(in_{size}/out_{size})
            
          
          Nearest neighbor interpolation:
          
T
tink2123 已提交
213
          if:
214 215 216 217 218 219 220 221
              align_corners = False

              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:

              H_out = \left \lfloor {H_{in} * scale_{}factor}} \right \rfloor
              W_out = \left \lfloor {W_{in} * scale_{}factor}} \right \rfloor

T
tink2123 已提交
222
          else:
223 224 225 226 227 228 229 230 231 232
              align_corners = True

              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:

              H_out = round(H_{in} * scale_{factor})
              W_out = round(W_{in} * scale_{factor})

          Bilinear interpolation:

T
tink2123 已提交
233
          if:
234 235 236 237 238 239 240 241 242
              align_corners = False , align_mode = 0
              
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              
              H_out = (H_{in}+0.5) * scale_{factor} - 0.5
              W_out = (W_{in}+0.5) * scale_{factor} - 0.5


T
tink2123 已提交
243
          else:
244 245 246 247 248 249 250
           
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:

              H_out = H_{in} * scale_{factor}
              W_out = W_{in} * scale_{factor}

K
Kaipeng Deng 已提交
251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271
          Trilinear interpolation:

          if:
              align_corners = False , align_mode = 0
              
              input : (N,C,D_in,H_in,W_in)
              output: (N,C,D_out,H_out,W_out) where:
              
              D_out = (D_{in}+0.5) * scale_{factor} - 0.5
              H_out = (H_{in}+0.5) * scale_{factor} - 0.5
              W_out = (W_{in}+0.5) * scale_{factor} - 0.5


          else:
           
              input : (N,C,D_in,H_in,W_in)
              output: (N,C,D_out,H_out,W_out) where:

              D_out = D_{in} * scale_{factor}
              H_out = H_{in} * scale_{factor}
              W_out = W_{in} * scale_{factor}
272 273
          

274
          For details of nearest neighbor interpolation, please refer to Wikipedia: 
275
          https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation
276 277 278

          For details of bilinear interpolation, please refer to Wikipedia: 
          https://en.wikipedia.org/wiki/Bilinear_interpolation
K
Kaipeng Deng 已提交
279 280 281

          For details of trilinear interpolation, please refer to Wikipedia: 
          https://en.wikipedia.org/wiki/Trilinear_interpolation
282 283 284 285
         )DOC");
  }
};

286
class InterpolateOpGrad : public framework::OperatorWithKernel {
287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) should not be null");
    auto dim_x = ctx->GetInputDim("X");
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
S
sneaxiy 已提交
303 304 305
    return framework::OpKernelType(
        ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
        ctx.GetPlace());
306 307 308
  }
};

S
sneaxiy 已提交
309 310 311 312 313 314 315 316 317
// Builds the "<forward type>_grad" op description for the interpolate ops.
class InterpolateGradDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  // The grad op reads X, the optional OutSize (only when the forward op had
  // one) and Out@GRAD, writes X@GRAD, and inherits all forward attributes.
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> grad_op(new framework::OpDesc());
    const auto& fwd = ForwardOp();

    grad_op->SetType(fwd.Type() + "_grad");
    grad_op->SetInput("X", Input("X"));

    const bool forward_has_out_size = fwd.Inputs().count("OutSize") != 0;
    if (forward_has_out_size) {
      grad_op->SetInput("OutSize", Input("OutSize"));
    }

    grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    grad_op->SetAttrMap(Attrs());
    return grad_op;
  }
};

// Marks Input(X) of the interpolate grad ops as a no-need-buffer variable.
// InterpolateOpGrad::InferShape only reads the dims of X, and its kernel type
// is taken from Out@GRAD. NOTE(review): presumably the grad kernels in
// interpolate_op.h also use only X's shape — confirm there.
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(InterpolateGradNoNeedBufferVarsInference,
                                      "X");

331 332 333 334
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// ---- Operator registration -------------------------------------------------
// Each forward op shares InterpolateOp/InterpolateOpMaker and generates its
// grad op through InterpolateGradDescMaker; each grad op declares X as a
// no-need-buffer input.
REGISTER_OPERATOR(bilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
                  ops::InterpolateGradDescMaker);
REGISTER_OPERATOR(bilinear_interp_grad, ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInference);

REGISTER_OPERATOR(nearest_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
                  ops::InterpolateGradDescMaker);
REGISTER_OPERATOR(nearest_interp_grad, ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInference);

REGISTER_OPERATOR(trilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
                  ops::InterpolateGradDescMaker);
REGISTER_OPERATOR(trilinear_interp_grad, ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInference);

// ---- CPU kernels -------------------------------------------------------------
// Forward kernels: float, double, uint8_t. Grad kernels: float, double.
REGISTER_OP_CPU_KERNEL(bilinear_interp,
                       ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>,
                       ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(bilinear_interp_grad,
                       ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);

REGISTER_OP_CPU_KERNEL(nearest_interp,
                       ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>,
                       ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(nearest_interp_grad,
                       ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);

REGISTER_OP_CPU_KERNEL(trilinear_interp,
                       ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>,
                       ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(trilinear_interp_grad,
                       ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);