/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/interpolate_op.h"

#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

namespace paddle {
namespace operators {

using framework::Tensor;
using DataLayout = framework::DataLayout;

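// Shape inference for 3-D input (NCW / NWC). The output width is resolved,
// in priority order, from Input(SizeTensor), Input(OutSize) at runtime,
// Input(Scale) or attr(scale), and finally attr(out_w); -1 marks a
// dimension that is only known at runtime.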
static void Interpolate1DInferShapeCheck(framework::InferShapeContext* ctx) {
  auto dim_x = ctx->GetInputDim("X");
  auto interp_method = ctx->Attrs().Get<std::string>("interp_method");

  PADDLE_ENFORCE_EQ("linear", interp_method,
                    platform::errors::InvalidArgument(
                        "Interpolation method can only be \"linear\" when"
                        "Input(X) dimension is 3, but got method = %s .",
                        interp_method));
  const DataLayout data_layout = framework::StringToDataLayout(
      ctx->Attrs().Get<std::string>("data_layout"));

  if (ctx->HasInputs("SizeTensor")) {
    // top priority size
    auto inputs_name = ctx->Inputs("SizeTensor");
    PADDLE_ENFORCE_EQ(
        inputs_name.size(), 1,
        platform::errors::InvalidArgument(
            "Input(SizeTensor)'size of Op(interpolate) must be 1. "
            "Attr(out_shape)'s length must be 1 for 3-D input tensor, but got "
            "size = %d .",
            inputs_name.size()));
    int out_w = ctx->Attrs().Get<int>("out_w");
    framework::DDim dim_out;
    if (data_layout == DataLayout::kNCHW) {
      dim_out = {dim_x[0], dim_x[1], out_w};
    } else {
      dim_out = {dim_x[0], out_w, dim_x[2]};
    }
    ctx->SetOutputDim("Out", dim_out);

    return;
  }

  int out_w;
  if (ctx->HasInput("Scale")) {
    auto scale_tensor = ctx->GetInputDim("Scale");
    PADDLE_ENFORCE_EQ(
        scale_tensor.size(), 1,
        platform::errors::InvalidArgument(
            "Scale's dimension size must be 1, but got dimension = %d .",
            scale_tensor.size()));
    out_w = -1;
  } else {
    float scale = ctx->Attrs().Get<float>("scale");
    if (scale > 0) {
      // round down
      out_w = (data_layout == DataLayout::kNCHW
                   ? static_cast<int>(dim_x[2] * scale)
                   : static_cast<int>(dim_x[1] * scale));
      // protect when input shape is -1
      out_w = out_w > 0 ? out_w : -1;
    } else {
      out_w = ctx->Attrs().Get<int>("out_w");
    }
  }

  if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
    auto out_size_dim = ctx->GetInputDim("OutSize");
    PADDLE_ENFORCE_EQ(
        out_size_dim.size(), 1,
        platform::errors::InvalidArgument(
            "OutSize's dimension size must be 1, but got dimention = %d .",
            out_size_dim.size()));
    PADDLE_ENFORCE_EQ(
        out_size_dim[0], 1,
        platform::errors::InvalidArgument(
            "OutSize's 0-th dimension's value must be 1, but got value = %d .",
            out_size_dim[0]));
    ctx->ShareLoD("X", "Out");
    return;
  }

  framework::DDim dim_out;
  if (data_layout == DataLayout::kNCHW) {
    dim_out = {dim_x[0], dim_x[1], out_w};
  } else {
    dim_out = {dim_x[0], out_w, dim_x[2]};
  }
  ctx->SetOutputDim("Out", dim_out);
}

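// Shape inference for 4-D input (NCHW / NHWC), with the same size-priority
// rules as the 1-D case above; the 5-D (NCDHW / NDHWC) check below follows
// the same pattern.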
static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) {
  auto dim_x = ctx->GetInputDim("X");
  auto interp_method = ctx->Attrs().Get<std::string>("interp_method");

  PADDLE_ENFORCE_EQ("bilinear" == interp_method || "nearest" == interp_method ||
                        "bicubic" == interp_method,
                    true,
                    platform::errors::InvalidArgument(
                        "Interpolation method can only be \"bilinear\" "
                        "or \"nearest\" or \"bicubic\" when "
                        "Input(X) dimension is 4, but got method is %s.",
                        interp_method));
  const DataLayout data_layout = framework::StringToDataLayout(
      ctx->Attrs().Get<std::string>("data_layout"));

  if (ctx->HasInputs("SizeTensor")) {
    // top priority size
    auto inputs_name = ctx->Inputs("SizeTensor");
    PADDLE_ENFORCE_EQ(
        inputs_name.size(), 2,
        platform::errors::InvalidArgument(
            "Input(SizeTensor)'size of Op(interpolate) must be 2. "
            "Attr(out_shape)'s length must be 2 for 4-D input "
            "tensor, but got size = %d .",
            inputs_name.size()));
    int out_h = ctx->Attrs().Get<int>("out_h");
    int out_w = ctx->Attrs().Get<int>("out_w");
    framework::DDim dim_out;
    if (data_layout == DataLayout::kNCHW) {
      dim_out = {dim_x[0], dim_x[1], out_h, out_w};
    } else {
      dim_out = {dim_x[0], out_h, out_w, dim_x[3]};
    }
    ctx->SetOutputDim("Out", dim_out);

    return;
  }

  int out_h, out_w;
  if (ctx->HasInput("Scale")) {
    auto scale_tensor = ctx->GetInputDim("Scale");
    PADDLE_ENFORCE_EQ(
        scale_tensor.size(), 1,
        platform::errors::InvalidArgument(
            "Scale's dimension size must be 1, but got dimension = %d .",
            scale_tensor.size()));
    out_h = -1;
    out_w = -1;
  } else {
    float scale = ctx->Attrs().Get<float>("scale");
    if (scale > 0) {
      // round down
      out_h = (data_layout == DataLayout::kNCHW
                   ? static_cast<int>(dim_x[2] * scale)
                   : static_cast<int>(dim_x[1] * scale));
      out_w = (data_layout == DataLayout::kNCHW
                   ? static_cast<int>(dim_x[3] * scale)
                   : static_cast<int>(dim_x[2] * scale));
      // protect when input shape is -1
      out_h = out_h > 0 ? out_h : -1;
      out_w = out_w > 0 ? out_w : -1;
    } else {
      out_h = ctx->Attrs().Get<int>("out_h");
      out_w = ctx->Attrs().Get<int>("out_w");
    }
  }

  if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
    auto out_size_dim = ctx->GetInputDim("OutSize");
    PADDLE_ENFORCE_EQ(
        out_size_dim.size(), 1,
        platform::errors::InvalidArgument("OutSize's dimension size must be 1, "
                                          "but got dimension size is %d .",
                                          out_size_dim.size()));
    PADDLE_ENFORCE_EQ(
        out_size_dim[0], 2,
        platform::errors::InvalidArgument(
            "OutSize's dimension[0] must be 2, but got dimension[0] is %d .",
            out_size_dim[0]));
    ctx->ShareLoD("X", "Out");
    return;
  }

  framework::DDim dim_out;
  if (data_layout == DataLayout::kNCHW) {
    dim_out = {dim_x[0], dim_x[1], out_h, out_w};
  } else {
    dim_out = {dim_x[0], out_h, out_w, dim_x[3]};
  }
  ctx->SetOutputDim("Out", dim_out);
}

static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) {
  auto dim_x = ctx->GetInputDim("X");
  auto interp_method = ctx->Attrs().Get<std::string>("interp_method");

  PADDLE_ENFORCE_EQ(
      "trilinear", interp_method,
      platform::errors::InvalidArgument(
          "Interpolation method can only be \"trilinear\" when Input(X) "
          "dimension is 5, but got method = %s .",
          interp_method));
  const DataLayout data_layout = framework::StringToDataLayout(
      ctx->Attrs().Get<std::string>("data_layout"));

  if (ctx->HasInputs("SizeTensor")) {
    // top priority size
    auto inputs_name = ctx->Inputs("SizeTensor");
    PADDLE_ENFORCE_EQ(
        inputs_name.size(), 3,
        platform::errors::InvalidArgument(
            "Input(SizeTensor)'s size of Op(interpolate) must be 3. "
            "Attr(out_shape)'s length must be 3 for 5-D input "
            "tensor, but got size = %d .",
            inputs_name.size()));
    int out_d = ctx->Attrs().Get<int>("out_d");
    int out_h = ctx->Attrs().Get<int>("out_h");
    int out_w = ctx->Attrs().Get<int>("out_w");
    framework::DDim dim_out;
    if (data_layout == DataLayout::kNCHW) {
      dim_out = {dim_x[0], dim_x[1], out_d, out_h, out_w};
    } else {
      dim_out = {dim_x[0], out_d, out_h, out_w, dim_x[4]};
    }
    ctx->SetOutputDim("Out", dim_out);

    return;
  }

  int out_d, out_h, out_w;
  if (ctx->HasInput("Scale")) {
    auto scale_tensor = ctx->GetInputDim("Scale");
    PADDLE_ENFORCE_EQ(
        scale_tensor.size(), 1,
        platform::errors::InvalidArgument(
            "Scale's dimension size must be 1, but got size = %d .",
            scale_tensor.size()));
    out_d = -1;
    out_h = -1;
    out_w = -1;
  } else {
    float scale = ctx->Attrs().Get<float>("scale");
    if (scale > 0) {
      // round down
      out_d = (data_layout == DataLayout::kNCHW
                   ? static_cast<int>(dim_x[2] * scale)
                   : static_cast<int>(dim_x[1] * scale));
      out_h = (data_layout == DataLayout::kNCHW
                   ? static_cast<int>(dim_x[3] * scale)
                   : static_cast<int>(dim_x[2] * scale));
      out_w = (data_layout == DataLayout::kNCHW
                   ? static_cast<int>(dim_x[4] * scale)
                   : static_cast<int>(dim_x[3] * scale));
      // protect when input shape is -1
      out_d = out_d > 0 ? out_d : -1;
      out_h = out_h > 0 ? out_h : -1;
      out_w = out_w > 0 ? out_w : -1;
    } else {
      out_d = ctx->Attrs().Get<int>("out_d");
      out_h = ctx->Attrs().Get<int>("out_h");
      out_w = ctx->Attrs().Get<int>("out_w");
    }
  }

  if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
    auto out_size_dim = ctx->GetInputDim("OutSize");
    PADDLE_ENFORCE_EQ(
        out_size_dim.size(), 1,
        platform::errors::InvalidArgument(
            "OutSize's dimension size must be 1, but got size is %d.",
            out_size_dim.size()));
    PADDLE_ENFORCE_EQ(out_size_dim[0], 3,
                      platform::errors::InvalidArgument(
                          "OutSize's dim[0] must be 3, but got size is %d.",
                          out_size_dim[0]));
    ctx->ShareLoD("X", "Out");
    return;
  }

  framework::DDim dim_out;
  if (data_layout == DataLayout::kNCHW) {
    dim_out = {dim_x[0], dim_x[1], out_d, out_h, out_w};
  } else {
    dim_out = {dim_x[0], out_d, out_h, out_w, dim_x[4]};
  }
  ctx->SetOutputDim("Out", dim_out);
}

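// Forward operator: checks inputs/outputs and dispatches shape inference
// on the rank of Input(X) (3-D, 4-D or 5-D).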
class InterpolateOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Interpolate");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Interpolate");

    auto dim_x = ctx->GetInputDim("X");  // NCW, NCHW or NCDHW format
    PADDLE_ENFORCE(
        dim_x.size() == 3 || dim_x.size() == 4 || dim_x.size() == 5,
        platform::errors::Unimplemented(
            "Input(X) dimension must be 3, 4 or 5, but got dimension = %d .",
            dim_x.size()));
    if (dim_x.size() == 3) {
      // shape check for 1D interpolate for input tensor shape NCW
      Interpolate1DInferShapeCheck(ctx);
    } else if (dim_x.size() == 4) {
      // shape check for 2D interpolate for input tensor shape NCHW
      Interpolate2DInferShapeCheck(ctx);
    } else {  // dim_x.size() == 5
      // shape check for 3D interpolate for input tensor shape NCDHW
      Interpolate3DInferShapeCheck(ctx);
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
    framework::LibraryType library = framework::LibraryType::kPlain;
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");

#ifdef PADDLE_WITH_MKLDNN
    const auto& interp_method = ctx.Attr<std::string>("interp_method");
    // TODO(danqing): support other interp_method
    if (this->CanMKLDNNBeUsed(ctx, data_type) &&
        (interp_method == "nearest" || interp_method == "bilinear")) {
      layout = framework::DataLayout::kMKLDNN;
      library = framework::LibraryType::kMKLDNN;
    }
#endif

    return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
#ifdef PADDLE_WITH_MKLDNN
    if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
        (tensor.layout() != framework::DataLayout::kMKLDNN)) {
      auto attrs = Attrs();
      auto ar = paddle::framework::AttrReader(attrs);
      const std::string data_format = ar.Get<std::string>("data_layout");
      auto dl = framework::StringToDataLayout(data_format);
      // Some models may have intentionally set "AnyLayout" for pool
      // op. Treat this as NCHW (default data_format value)
      if (dl != framework::DataLayout::kAnyLayout) {
        return framework::OpKernelType(expected_kernel_type.data_type_,
                                       tensor.place(), dl);
      }
    }
#endif
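    // SizeTensor and Scale are small auxiliary tensors that are consumed
    // as-is, so return the expected kernel type directly to skip data
    // transform (layout/place conversion) for them.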
    if (var_name == "SizeTensor" || var_name == "Scale") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};

class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "The input tensor of interpolate operator, "
K
Kaipeng Deng 已提交
377 378
             "This is a 4-D tensor with shape of [N, C, H, W] or a "
             "5-D tensor with shape of [N, C, D, H, W].");
    AddInput("OutSize",
380
             "This is a 1-D tensor with two numbers to specify output size. "
K
Kaipeng Deng 已提交
381 382
             "It should be [output_height, output_width] when input is a 4-D "
             "tensor and should be [output_depth, output_height, output_width] "
383 384 385 386 387 388 389 390 391 392 393 394 395
             "when input is a 5-D tensor. It has a higher priority than "
             "the attr(out_d), attr(out_h), attr(out_w) and attr(scale).")
        .AsDispensable();
    AddInput("SizeTensor",
             "(vector<Tensor<int32>>, optional). If provided, interpolate will "
             "use this. The shape of the tensor in vector MUST BE [1]. "
             "It has the highest priority compare with Input(OutSize) and "
             "attr(out_d), attr(out_h), attr(out_w) and attr(scale).")
        .AsDuplicable()
        .AsDispensable();
    AddInput("Scale",
             "This is a 1-D tensor with one number to specify output scale. "
             "It has the higher priority compare with attr(scale).")
        .AsDispensable();
    AddOutput("Out",
              "The output tensor of interpolate operator, "
K
Kaipeng Deng 已提交
399
              "This is a tensor in same rank with Input(X).");

    AddAttr<std::string>(
        "data_layout",
        "(string, default NCHW) An optional string from: \"NHWC\", "
        "\"NCHW\". Specify whether the data format of the input and output "
        "data is channel_first or channel_last.")
        .SetDefault("NCHW");
    AddAttr<int>("out_d", "output depth of interpolate op.").SetDefault(0);
    AddAttr<int>("out_h", "output height of interpolate op.").SetDefault(0);
    AddAttr<int>("out_w", "output width of interpolate op.").SetDefault(0);
    AddAttr<float>("scale", "scale factor of interpolate op.").SetDefault(0.);
    AddAttr<std::string>("interp_method",
                         "(string, default \"bilinear\"), interpolation "
414 415
                         "method, can be \"linear\" for linear interpolation"
                         ",\"bilinear\" for "
K
Kaipeng Deng 已提交
416 417
                         "bilinear interpolation, \"trilinear\" for trilinear "
                         "interpolation and \"nearest\" for nearest "
X
xiaoting 已提交
418 419
                         "neighbor interpolation, and \"bicubic\" for bicubic"
                         "interpolation.")
        .SetDefault("bilinear");
    AddAttr<bool>(
        "align_corners",
        "an optional bool, defaults to True. "
        "If True, the centers of the 4 corner pixels of the input and "
        "output tensors are aligned, preserving the values at the corner "
        "pixels; if False, they are not aligned.")
        .SetDefault(true);
    AddAttr<int>("align_mode",
T
Tink_Y 已提交
429
                 "(int, default \'1\'), optional for bilinear interpolation, "
T
tink2123 已提交
430 431
                 "can be \'0\' for src_idx = scale*(dst_indx+0.5)-0.5 , "
                 "can be \'1\' for src_idx = scale*dst_index .")
T
tink2123 已提交
432
        .SetDefault(1);
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false)
        .AsExtra();
    AddComment(R"DOC(
          This operator samples input X to the given output shape by using
          the specified interpolation method. The interpolation methods can
          be \"nearest\" for nearest neighbor interpolation, \"linear\" for
          linear interpolation, \"bilinear\" for bilinear interpolation,
          \"trilinear\" for trilinear interpolation and \"bicubic\" for
          bicubic interpolation.

          Nearest neighbor interpolation is performed in both the 3rd
          dimension (height direction) and the 4th dimension (width
          direction) of the input tensor.
           
          Linear interpolation is the method of using a line connecting two
          known quantities to determine the value of an unknown quantity
          between them.

          Bilinear interpolation is an extension of linear interpolation for 
          interpolating functions of two variables (e.g. H-direction and 
          W-direction in this op) on a rectilinear 2D grid. The key idea is 
          to perform linear interpolation first in one direction, and then 
          again in the other direction.

          Trilinear interpolation is an extension of linear interpolation for 
          interpolating functions of three variables (e.g. D-direction, 
          H-direction and W-direction in this op) on a rectilinear 3D grid. 
          The linear interpolation is performed on three directions.

          Bicubic interpolation is an extension of cubic interpolation for interpolating
          data points on a two-dimensional regular grid. The interpolated surface is
          smoother than corresponding surfaces obtained by bilinear interpolation or
          nearest-neighbor interpolation.

          align_corners and align_mode are optional parameters; the
          calculation method of the interpolation can be selected by them.

          Example:

          For scale:
          
            if align_corners = True and out_{size}>1 :

              scale_{factor} = (in_{size}-1.0)/(out_{size}-1.0)
            
            else:
              
              scale_{factor} = float(in_{size}/out_{size})
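
          For instance, with in_size = 4 and out_size = 8:

            align_corners = True : scale_factor = (4 - 1.0) / (8 - 1.0) = 3/7
            align_corners = False: scale_factor = 4.0 / 8.0 = 0.5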
            
          
          Nearest neighbor interpolation:
          
          if:
              align_corners = False

              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:

              H_out = \left \lfloor {H_{in} * scale_{factor}} \right \rfloor
              W_out = \left \lfloor {W_{in} * scale_{factor}} \right \rfloor

          else:
              align_corners = True

              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:

              H_out = round(H_{in} * scale_{factor})
              W_out = round(W_{in} * scale_{factor})

          Bilinear interpolation:

          if:
              align_corners = False , align_mode = 0
              
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              
              H_out = (H_{in}+0.5) * scale_{factor} - 0.5
              W_out = (W_{in}+0.5) * scale_{factor} - 0.5


          else:
           
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:

              H_out = H_{in} * scale_{factor}
              W_out = W_{in} * scale_{factor}

          Trilinear interpolation:

          if:
              align_corners = False , align_mode = 0
              
              input : (N,C,D_in,H_in,W_in)
              output: (N,C,D_out,H_out,W_out) where:
              
              D_out = (D_{in}+0.5) * scale_{factor} - 0.5
              H_out = (H_{in}+0.5) * scale_{factor} - 0.5
              W_out = (W_{in}+0.5) * scale_{factor} - 0.5


          else:
           
              input : (N,C,D_in,H_in,W_in)
              output: (N,C,D_out,H_out,W_out) where:

              D_out = D_{in} * scale_{factor}
              H_out = H_{in} * scale_{factor}
              W_out = W_{in} * scale_{factor}

          Bicubic interpolation:

          if:
              align_corners = False
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              H_out = (H_{in}+0.5) * scale_{factor} - 0.5
              W_out = (W_{in}+0.5) * scale_{factor} - 0.5
          else:
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              H_out = H_{in} * scale_{factor}
              W_out = W_{in} * scale_{factor}

          For details of nearest neighbor interpolation, please refer to Wikipedia: 
          https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation

          For details of bilinear interpolation, please refer to Wikipedia: 
          https://en.wikipedia.org/wiki/Bilinear_interpolation

          For details of trilinear interpolation, please refer to Wikipedia: 
          https://en.wikipedia.org/wiki/Trilinear_interpolation

          For details of bicubic interpolation, please refer to Wikipedia:
          https://en.wikipedia.org/wiki/Bicubic_interpolation
         )DOC");
  }
};

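// Backward operator: propagates Out@GRAD to X@GRAD, which takes X's shape.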
class InterpolateOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "InterpolateGrad");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                   "Out@GRAD", "InterpolateGrad");

    auto dim_x = ctx->GetInputDim("X");
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.GetPlace());
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "SizeTensor" || var_name == "Scale") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};

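// Builds the corresponding *_grad op, forwarding X and whichever of the
// optional inputs (SizeTensor, OutSize, Scale) were provided, so the
// backward kernel can recover the output shape.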
template <typename T>
class InterpolateGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType(this->ForwardOpType() + "_grad");
    op->SetInput("X", this->Input("X"));
    if (this->HasInput("SizeTensor") > 0) {
      op->SetInput("SizeTensor", this->Input("SizeTensor"));
    }
    if (this->HasInput("OutSize") > 0) {
      op->SetInput("OutSize", this->Input("OutSize"));
    }
    if (this->HasInput("Scale") > 0) {
      op->SetInput("Scale", this->Input("Scale"));
    }
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
  }
};

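// The gradient kernel only needs X's shape, not its contents, so X's data
// buffer does not have to be kept alive for the backward pass.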
DECLARE_NO_NEED_BUFFER_VARS_INFERER(InterpolateGradNoNeedBufferVarsInferer,
                                    "X");

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
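// All five interpolation modes (bilinear, nearest, trilinear, bicubic,
// linear) share the same operator, maker and grad-maker classes; the
// registrations differ only in op name and the CPU kernels bound below.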
REGISTER_OPERATOR(bilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
                  ops::InterpolateGradMaker<paddle::framework::OpDesc>,
                  ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(bilinear_interp_grad, ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(nearest_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
                  ops::InterpolateGradMaker<paddle::framework::OpDesc>,
                  ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(nearest_interp_grad, ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(trilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
                  ops::InterpolateGradMaker<paddle::framework::OpDesc>,
                  ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(trilinear_interp_grad, ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(bicubic_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
                  ops::InterpolateGradMaker<paddle::framework::OpDesc>,
                  ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(bicubic_interp_grad, ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>,
                       ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(bilinear_interp_grad, ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);
REGISTER_OP_CPU_KERNEL(nearest_interp, ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>,
                       ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(nearest_interp_grad, ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);
REGISTER_OP_CPU_KERNEL(trilinear_interp, ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>,
                       ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(trilinear_interp_grad, ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);
REGISTER_OPERATOR(linear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
                  ops::InterpolateGradMaker<paddle::framework::OpDesc>,
                  ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(linear_interp_grad, ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(linear_interp, ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>,
                       ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(linear_interp_grad, ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);
REGISTER_OP_CPU_KERNEL(bicubic_interp, ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>);
REGISTER_OP_CPU_KERNEL(bicubic_interp_grad, ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);