/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/fluid/operators/interpolate_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

namespace paddle {
namespace operators {

using framework::Tensor;
using DataLayout = framework::DataLayout;

static void Interpolate1DInferShapeCheck(framework::InferShapeContext* ctx) {
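  // Output-size priority (highest first): Inputs(SizeTensor), then
  // Input(OutSize) (at runtime), then Input(Scale)/Attr(scale), then
  // Attr(out_w).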
  auto dim_x = ctx->GetInputDim("X");
  auto interp_method = ctx->Attrs().Get<std::string>("interp_method");

  PADDLE_ENFORCE_EQ("linear", interp_method,
                    platform::errors::InvalidArgument(
                        "Interpolation method can only be \"linear\" when "
                        "Input(X) dimension is 3, but got method = %s .",
                        interp_method));
  const DataLayout data_layout = framework::StringToDataLayout(
      ctx->Attrs().Get<std::string>("data_layout"));

  if (ctx->HasInputs("SizeTensor")) {
    // top priority size
    auto inputs_name = ctx->Inputs("SizeTensor");
    PADDLE_ENFORCE_EQ(
        inputs_name.size(), 1,
        platform::errors::InvalidArgument(
            "Input(SizeTensor)'s size of Op(interpolate) must be 1. "
            "Attr(out_shape)'s length must be 1 for 3-D input tensor, but got "
            "size = %d .",
            inputs_name.size()));
    int out_w = ctx->Attrs().Get<int>("out_w");
    framework::DDim dim_out;
    if (data_layout == DataLayout::kNCHW) {
      dim_out = {dim_x[0], dim_x[1], out_w};
    } else {
      dim_out = {dim_x[0], out_w, dim_x[2]};
    }
    ctx->SetOutputDim("Out", dim_out);

    return;
  }

  int out_w;
  if (ctx->HasInput("Scale")) {
    auto scale_tensor = ctx->GetInputDim("Scale");
    PADDLE_ENFORCE_EQ(
        scale_tensor.size(), 1,
        platform::errors::InvalidArgument(
            "Scale's dimension size must be 1, but got dimension = %d .",
            scale_tensor.size()));
    out_w = -1;
  } else {
    float scale = ctx->Attrs().Get<float>("scale");
    if (scale > 0) {
      // round down
      out_w = (data_layout == DataLayout::kNCHW
                   ? static_cast<int>(dim_x[2] * scale)
                   : static_cast<int>(dim_x[1] * scale));
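      // e.g. an NCHW input of shape [2, 3, 10] with scale = 1.55 gives
      // out_w = static_cast<int>(10 * 1.55) = 15 (rounded down).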
      // protect when input shape is -1
      out_w = out_w > 0 ? out_w : -1;
    } else {
      out_w = ctx->Attrs().Get<int>("out_w");
    }
  }

  if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
    auto out_size_dim = ctx->GetInputDim("OutSize");
    PADDLE_ENFORCE_EQ(
        out_size_dim.size(), 1,
        platform::errors::InvalidArgument(
            "OutSize's dimension size must be 1, but got dimension = %d .",
            out_size_dim.size()));
    PADDLE_ENFORCE_EQ(out_size_dim[0], 1,
                      platform::errors::InvalidArgument(
                          "OutSize's dim[0] must be 1, but got %d .",
                          out_size_dim[0]));
    ctx->ShareLoD("X", "Out");
    return;
  }

  framework::DDim dim_out;
  if (data_layout == DataLayout::kNCHW) {
    dim_out = {dim_x[0], dim_x[1], out_w};
  } else {
    dim_out = {dim_x[0], out_w, dim_x[2]};
  }
  ctx->SetOutputDim("Out", dim_out);
}

static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) {
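  // Same output-size priority as in Interpolate1DInferShapeCheck above.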
  auto dim_x = ctx->GetInputDim("X");
  auto interp_method = ctx->Attrs().Get<std::string>("interp_method");

  PADDLE_ENFORCE_EQ("bilinear" == interp_method || "nearest" == interp_method ||
                        "bicubic" == interp_method,
                    true, platform::errors::InvalidArgument(
                              "Interpolation method can only be \"bilinear\" "
                              "or \"nearest\" or \"bicubic\" when "
                              "Input(X) dimension is 4, but got method is %s.",
                              interp_method));
  const DataLayout data_layout = framework::StringToDataLayout(
      ctx->Attrs().Get<std::string>("data_layout"));

  if (ctx->HasInputs("SizeTensor")) {
    // top priority size
    auto inputs_name = ctx->Inputs("SizeTensor");
    PADDLE_ENFORCE_EQ(
        inputs_name.size(), 2,
        platform::errors::InvalidArgument(
            "Input(SizeTensor)'s size of Op(interpolate) must be 2. "
            "Attr(out_shape)'s length must be 2 for 4-D input "
            "tensor, but got size = %d .",
            inputs_name.size()));
    int out_h = ctx->Attrs().Get<int>("out_h");
    int out_w = ctx->Attrs().Get<int>("out_w");
    framework::DDim dim_out;
    if (data_layout == DataLayout::kNCHW) {
      dim_out = {dim_x[0], dim_x[1], out_h, out_w};
    } else {
      dim_out = {dim_x[0], out_h, out_w, dim_x[3]};
    }
    ctx->SetOutputDim("Out", dim_out);

    return;
  }

  int out_h, out_w;
  if (ctx->HasInput("Scale")) {
    auto scale_tensor = ctx->GetInputDim("Scale");
    PADDLE_ENFORCE_EQ(
        scale_tensor.size(), 1,
        platform::errors::InvalidArgument(
            "Scale's dimension size must be 1, but got dimension = %d .",
            scale_tensor.size()));
    out_h = -1;
    out_w = -1;
  } else {
    float scale = ctx->Attrs().Get<float>("scale");
    if (scale > 0) {
      // round down
      out_h = (data_layout == DataLayout::kNCHW
                   ? static_cast<int>(dim_x[2] * scale)
                   : static_cast<int>(dim_x[1] * scale));
      out_w = (data_layout == DataLayout::kNCHW
                   ? static_cast<int>(dim_x[3] * scale)
                   : static_cast<int>(dim_x[2] * scale));
      // protect when input shape is -1
      out_h = out_h > 0 ? out_h : -1;
      out_w = out_w > 0 ? out_w : -1;
    } else {
      out_h = ctx->Attrs().Get<int>("out_h");
      out_w = ctx->Attrs().Get<int>("out_w");
    }
  }

  if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
    auto out_size_dim = ctx->GetInputDim("OutSize");
    PADDLE_ENFORCE_EQ(
        out_size_dim.size(), 1,
        platform::errors::InvalidArgument("OutSize's dimension size must be 1, "
                                          "but got dimension size is %d .",
                                          out_size_dim.size()));
    PADDLE_ENFORCE_EQ(
        out_size_dim[0], 2,
        platform::errors::InvalidArgument(
            "OutSize's dimension[0] must be 2, but got dimension[0] is %d .",
            out_size_dim[0]));
    ctx->ShareLoD("X", "Out");
    return;
  }

  framework::DDim dim_out;
  if (data_layout == DataLayout::kNCHW) {
    dim_out = {dim_x[0], dim_x[1], out_h, out_w};
  } else {
    dim_out = {dim_x[0], out_h, out_w, dim_x[3]};
  }
  ctx->SetOutputDim("Out", dim_out);
}

static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) {
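  // Same output-size priority as in the 1-D and 2-D checks above.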
  auto dim_x = ctx->GetInputDim("X");
  auto interp_method = ctx->Attrs().Get<std::string>("interp_method");

  PADDLE_ENFORCE_EQ(
      "trilinear", interp_method,
      platform::errors::InvalidArgument(
          "Interpolation method can only be \"trilinear\" when Input(X) "
          "dimension is 5, but got method = %s .",
          interp_method));
  const DataLayout data_layout = framework::StringToDataLayout(
      ctx->Attrs().Get<std::string>("data_layout"));

  if (ctx->HasInputs("SizeTensor")) {
    // top priority size
    auto inputs_name = ctx->Inputs("SizeTensor");
    PADDLE_ENFORCE_EQ(
        inputs_name.size(), 3,
        platform::errors::InvalidArgument(
            "Input(SizeTensor)'s size of Op(interpolate) must be 3. "
            "Attr(out_shape)'s length must be 3 for 5-D input "
            "tensor, but got size = %d .",
            inputs_name.size()));
    int out_d = ctx->Attrs().Get<int>("out_d");
    int out_h = ctx->Attrs().Get<int>("out_h");
    int out_w = ctx->Attrs().Get<int>("out_w");
    framework::DDim dim_out;
    if (data_layout == DataLayout::kNCHW) {
      dim_out = {dim_x[0], dim_x[1], out_d, out_h, out_w};
    } else {
      dim_out = {dim_x[0], out_d, out_h, out_w, dim_x[4]};
    }
    ctx->SetOutputDim("Out", dim_out);

    return;
  }

  int out_d, out_h, out_w;
  if (ctx->HasInput("Scale")) {
    auto scale_tensor = ctx->GetInputDim("Scale");
    PADDLE_ENFORCE_EQ(
        scale_tensor.size(), 1,
        platform::errors::InvalidArgument(
            "Scale's dimension size must be 1, but got size = %d .",
            scale_tensor.size()));
    out_d = -1;
    out_h = -1;
    out_w = -1;
  } else {
    float scale = ctx->Attrs().Get<float>("scale");
    if (scale > 0) {
      // round down
      out_d = (data_layout == DataLayout::kNCHW
                   ? static_cast<int>(dim_x[2] * scale)
                   : static_cast<int>(dim_x[1] * scale));
      out_h = (data_layout == DataLayout::kNCHW
                   ? static_cast<int>(dim_x[3] * scale)
                   : static_cast<int>(dim_x[2] * scale));
      out_w = (data_layout == DataLayout::kNCHW
                   ? static_cast<int>(dim_x[4] * scale)
                   : static_cast<int>(dim_x[3] * scale));
      // protect when input shape is -1
      out_d = out_d > 0 ? out_d : -1;
      out_h = out_h > 0 ? out_h : -1;
      out_w = out_w > 0 ? out_w : -1;
    } else {
      out_d = ctx->Attrs().Get<int>("out_d");
      out_h = ctx->Attrs().Get<int>("out_h");
      out_w = ctx->Attrs().Get<int>("out_w");
    }
  }

  if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
    auto out_size_dim = ctx->GetInputDim("OutSize");
    PADDLE_ENFORCE_EQ(
        out_size_dim.size(), 1,
        platform::errors::InvalidArgument(
            "OutSize's dimension size must be 1, but got size is %d.",
            out_size_dim.size()));
    PADDLE_ENFORCE_EQ(out_size_dim[0], 3,
                      platform::errors::InvalidArgument(
                          "OutSize's dim[0] must be 3, but got size is %d.",
                          out_size_dim[0]));
    ctx->ShareLoD("X", "Out");
    return;
  }

  framework::DDim dim_out;
  if (data_layout == DataLayout::kNCHW) {
    dim_out = {dim_x[0], dim_x[1], out_d, out_h, out_w};
  } else {
    dim_out = {dim_x[0], out_d, out_h, out_w, dim_x[4]};
  }
  ctx->SetOutputDim("Out", dim_out);
}

class InterpolateOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Interpolate");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Interpolate");

    auto dim_x = ctx->GetInputDim("X");  // layout given by Attr("data_layout")
    PADDLE_ENFORCE(
        dim_x.size() == 3 || dim_x.size() == 4 || dim_x.size() == 5,
        platform::errors::Unimplemented(
            "Input(X) dimension must be 3, 4 or 5, but got dimension = %d .",
            dim_x.size()));
    if (dim_x.size() == 3) {
      // shape check for 1D interpolate for input tensor shape NCW
      Interpolate1DInferShapeCheck(ctx);
    } else if (dim_x.size() == 4) {
      // shape check for 2D interpolate for input tensor shape NCHW
      Interpolate2DInferShapeCheck(ctx);
    } else {  // dim_x.size() == 5
      // shape check for 3D interpolate for input tensor shape NCDHW
      Interpolate3DInferShapeCheck(ctx);
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
    framework::LibraryType library = framework::LibraryType::kPlain;
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");

#ifdef PADDLE_WITH_MKLDNN
    auto interp_method = ctx.Attr<std::string>("interp_method");
    // TODO(danqing): support other interp_method
    if (this->CanMKLDNNBeUsed(ctx, data_type) &&
        (interp_method == "nearest" || interp_method == "bilinear")) {
      layout = framework::DataLayout::kMKLDNN;
      library = framework::LibraryType::kMKLDNN;
    }
#endif

    return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
#ifdef PADDLE_WITH_MKLDNN
    if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
        (tensor.layout() != framework::DataLayout::kMKLDNN)) {
      auto attrs = Attrs();
      auto ar = paddle::framework::AttrReader(attrs);
      const std::string data_format = ar.Get<std::string>("data_layout");
      auto dl = framework::StringToDataLayout(data_format);
      // Some models may have intentionally set "AnyLayout" for pool
      // op. Treat this as NCHW (default data_format value)
      if (dl != framework::DataLayout::kAnyLayout) {
        return framework::OpKernelType(expected_kernel_type.data_type_,
                                       tensor.place(), dl);
      }
    }
#endif
    if (var_name == "SizeTensor" || var_name == "Scale") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};

class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "The input tensor of interpolate operator. "
             "This is a 3-D tensor with shape of [N, C, W], a 4-D tensor "
             "with shape of [N, C, H, W] or a 5-D tensor with shape of "
             "[N, C, D, H, W].");
    AddInput("OutSize",
             "This is a 1-D tensor that specifies the output size. "
             "It should be [output_height, output_width] when input is a 4-D "
             "tensor and should be [output_depth, output_height, output_width] "
             "when input is a 5-D tensor. It has a higher priority than "
             "the attr(out_d), attr(out_h), attr(out_w) and attr(scale).")
        .AsDispensable();
    AddInput("SizeTensor",
             "(vector<Tensor<int32>>, optional). If provided, interpolate will "
             "use this. The shape of the tensor in vector MUST BE [1]. "
             "It has the highest priority compared with Input(OutSize) and "
             "attr(out_d), attr(out_h), attr(out_w) and attr(scale).")
        .AsDuplicable()
        .AsDispensable();
    AddInput("Scale",
             "This is a 1-D tensor with one number to specify output scale. "
             "It has a higher priority than attr(scale).")
        .AsDispensable();
    AddOutput("Out",
              "The output tensor of interpolate operator. "
              "This is a tensor with the same rank as Input(X).");

    AddAttr<std::string>(
        "data_layout",
        "(string, default NCHW) An optional string from: \"NHWC\", \"NCHW\". "
        "Specifies whether the data format of the input and output data is "
        "channel_first (NCHW) or channel_last (NHWC).")
        .SetDefault("NCHW");
    AddAttr<int>("out_d", "output depth of interpolate op.").SetDefault(0);
    AddAttr<int>("out_h", "output height of interpolate op.").SetDefault(0);
    AddAttr<int>("out_w", "output width of interpolate op.").SetDefault(0);
    AddAttr<float>("scale", "scale factor of interpolate op.").SetDefault(0.);
    AddAttr<std::string>("interp_method",
                         "(string, default \"bilinear\"), interpolation "
                         "method, can be \"linear\" for linear interpolation, "
                         "\"bilinear\" for bilinear interpolation, \"trilinear\" "
                         "for trilinear interpolation, \"nearest\" for nearest "
                         "neighbor interpolation, and \"bicubic\" for bicubic "
                         "interpolation.")
        .SetDefault("bilinear");
    AddAttr<bool>(
        "align_corners",
        "an optional bool. Defaults to True. "
        "If True, the centers of the 4 corner pixels of the input and output "
        "tensors are aligned, preserving the values at the corner pixels. "
        "If False, they are not aligned.")
        .SetDefault(true);
    AddAttr<int>("align_mode",
                 "(int, default \'1\'), optional for bilinear interpolation, "
                 "can be \'0\' for src_idx = scale*(dst_index+0.5)-0.5, "
                 "or \'1\' for src_idx = scale*dst_index.")
        .SetDefault(1);
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false);
    AddComment(R"DOC(
          This operator samples the input X to the given output shape using the
          specified interpolation method, which can be "nearest" for nearest
          neighbor interpolation, "linear" for linear interpolation, "bilinear"
          for bilinear interpolation, "trilinear" for trilinear interpolation,
          or "bicubic" for bicubic interpolation.

          Nearest neighbor interpolation performs nearest neighbor interpolation
          in both the 3rd dimension (height direction) and the 4th dimension
          (width direction) of the input tensor.
           
          Linear interpolation is the method of using a line connecting two known quantities 
          to determine the value of an unknown quantity between the two known quantities. 
          
          Bilinear interpolation is an extension of linear interpolation for 
          interpolating functions of two variables (e.g. H-direction and 
          W-direction in this op) on a rectilinear 2D grid. The key idea is 
          to perform linear interpolation first in one direction, and then 
          again in the other direction.

          Trilinear interpolation is an extension of linear interpolation for 
          interpolating functions of three variables (e.g. D-direction, 
          H-direction and W-direction in this op) on a rectilinear 3D grid. 
          The linear interpolation is performed on three directions.

          Bicubic interpolation is an extension of cubic interpolation for interpolating
          data points on a two-dimensional regular grid. The interpolated surface is
          smoother than corresponding surfaces obtained by bilinear interpolation or
          nearest-neighbor interpolation.

          align_corners and align_mode are optional parameters; the calculation
          method of the interpolation can be selected by them.
          
          Example:

          For scale:
          
            if align_corners = True and out_{size}>1 :

              scale_{factor} = (in_{size}-1.0)/(out_{size}-1.0)
            
            else:
              
              scale_{factor} = float(in_{size}/out_{size})
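
            For example, with in_size = 5 and out_size = 3, align_corners = True
            gives scale_factor = (5-1)/(3-1) = 2.0, while align_corners = False
            gives scale_factor = 5/3, about 1.667.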
            
          
          Nearest neighbor interpolation:
          
          if:
              align_corners = False

              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:

              H_out = \left \lfloor {H_{in} * scale_{factor}} \right \rfloor
              W_out = \left \lfloor {W_{in} * scale_{factor}} \right \rfloor

          else:
              align_corners = True

              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:

              H_out = round(H_{in} * scale_{factor})
              W_out = round(W_{in} * scale_{factor})
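
              For example, H_in = 5, W_in = 5 and scale_factor = 2.0 give an
              output of shape (N, C, 10, 10) under either setting.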

          Bilinear interpolation:

          if:
              align_corners = False , align_mode = 0
              
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              
              H_out = (H_{in}+0.5) * scale_{factor} - 0.5
              W_out = (W_{in}+0.5) * scale_{factor} - 0.5


          else:
           
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:

              H_out = H_{in} * scale_{factor}
              W_out = W_{in} * scale_{factor}

          Trilinear interpolation:

          if:
              align_corners = False , align_mode = 0
              
              input : (N,C,D_in,H_in,W_in)
              output: (N,C,D_out,H_out,W_out) where:
              
              D_out = (D_{in}+0.5) * scale_{factor} - 0.5
              H_out = (H_{in}+0.5) * scale_{factor} - 0.5
              W_out = (W_{in}+0.5) * scale_{factor} - 0.5


          else:
           
              input : (N,C,D_in,H_in,W_in)
              output: (N,C,D_out,H_out,W_out) where:

              D_out = D_{in} * scale_{factor}
              H_out = H_{in} * scale_{factor}
              W_out = W_{in} * scale_{factor}

          Bicubic interpolation:

          if:
              align_corners = False
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              H_out = (H_{in}+0.5) * scale_{factor} - 0.5
              W_out = (W_{in}+0.5) * scale_{factor} - 0.5
          else:
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              H_out = H_{in} * scale_{factor}
              W_out = W_{in} * scale_{factor}

          For details of nearest neighbor interpolation, please refer to Wikipedia: 
          https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation

          For details of bilinear interpolation, please refer to Wikipedia: 
          https://en.wikipedia.org/wiki/Bilinear_interpolation

          For details of trilinear interpolation, please refer to Wikipedia: 
          https://en.wikipedia.org/wiki/Trilinear_interpolation

          For details of bicubic interpolation, please refer to Wikipedia:
          https://en.wikipedia.org/wiki/Bicubic_interpolation
         )DOC");
  }
};

class InterpolateOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "InterpolateGrad");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                   "Out@GRAD", "InterpolateGrad");

    auto dim_x = ctx->GetInputDim("X");
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.GetPlace());
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "SizeTensor" || var_name == "Scale") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};

template <typename T>
class InterpolateGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType(this->ForwardOpType() + "_grad");
    op->SetInput("X", this->Input("X"));
    if (this->HasInput("SizeTensor")) {
      op->SetInput("SizeTensor", this->Input("SizeTensor"));
    }
    if (this->HasInput("OutSize")) {
      op->SetInput("OutSize", this->Input("OutSize"));
    }
    if (this->HasInput("Scale")) {
      op->SetInput("Scale", this->Input("Scale"));
    }
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
  }
};

DECLARE_NO_NEED_BUFFER_VARS_INFERER(InterpolateGradNoNeedBufferVarsInferer,
                                    "X");

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(bilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
                  ops::InterpolateGradMaker<paddle::framework::OpDesc>,
                  ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(bilinear_interp_grad, ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(nearest_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
                  ops::InterpolateGradMaker<paddle::framework::OpDesc>,
                  ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(nearest_interp_grad, ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(trilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
                  ops::InterpolateGradMaker<paddle::framework::OpDesc>,
                  ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(trilinear_interp_grad, ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(bicubic_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
                  ops::InterpolateGradMaker<paddle::framework::OpDesc>,
                  ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(bicubic_interp_grad, ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>,
                       ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(bilinear_interp_grad, ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);
REGISTER_OP_CPU_KERNEL(nearest_interp, ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>,
                       ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(nearest_interp_grad, ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);
REGISTER_OP_CPU_KERNEL(trilinear_interp, ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>,
                       ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(trilinear_interp_grad, ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);
REGISTER_OPERATOR(linear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
                  ops::InterpolateGradMaker<paddle::framework::OpDesc>,
                  ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(linear_interp_grad, ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(linear_interp, ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>,
                       ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(linear_interp_grad, ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);
REGISTER_OP_CPU_KERNEL(bicubic_interp, ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>);
REGISTER_OP_CPU_KERNEL(bicubic_interp_grad, ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);