interpolate_op.cc 26.9 KB
Newer Older
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

12
#include "paddle/fluid/operators/interpolate_op.h"
13

S
sneaxiy 已提交
14
#include <memory>
15
#include <string>
16
#include <vector>
17

18
#include "paddle/fluid/framework/op_registry.h"
19 20 21
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
22 23 24 25

namespace paddle {
namespace operators {

26
using DataLayout = framework::DataLayout;
27

28 29 30 31
// Shape inference for 1-D ("linear") interpolation, where Input(X) is a 3-D
// tensor, NCW when data_layout is NCHW-style and NWC otherwise.
// Priority of size sources: Input(SizeTensor) > Input(OutSize) >
// Input(Scale) > attr(scale) > attr(out_w).
static void Interpolate1DInferShapeCheck(framework::InferShapeContext* ctx) {
  auto dim_x = ctx->GetInputDim("X");
  auto interp_method = ctx->Attrs().Get<std::string>("interp_method");

  PADDLE_ENFORCE_EQ("linear",
                    interp_method,
                    platform::errors::InvalidArgument(
                        "Interpolation method can only be \"linear\" when"
                        "Input(X) dimension is 3, but got method = %s .",
                        interp_method));
  const DataLayout data_layout = framework::StringToDataLayout(
      ctx->Attrs().Get<std::string>("data_layout"));

  if (ctx->HasInputs("SizeTensor")) {
    // top priority size: actual width is read at runtime, the attr only
    // provides the static shape
    auto inputs_name = ctx->Inputs("SizeTensor");
    PADDLE_ENFORCE_EQ(
        inputs_name.size(),
        1,
        platform::errors::InvalidArgument(
            "Input(SizeTensor)'s size of Op(interpolate) must be 1. "
            "Attr(out_shape)'s length must be 1 for 3-D input tensor, but got "
            "size = %d .",
            inputs_name.size()));
    int out_w = ctx->Attrs().Get<int>("out_w");
    framework::DDim dim_out;
    if (data_layout == DataLayout::kNCHW) {
      dim_out = {dim_x[0], dim_x[1], out_w};
    } else {
      dim_out = {dim_x[0], out_w, dim_x[2]};
    }
    ctx->SetOutputDim("Out", dim_out);

    return;
  }

  int out_w;
  if (ctx->HasInput("Scale")) {
    auto scale_tensor = ctx->GetInputDim("Scale");
    PADDLE_ENFORCE_EQ(
        scale_tensor.size(),
        1,
        platform::errors::InvalidArgument(
            "Scale's dimension size must be 1, but got dimension = %d .",
            scale_tensor.size()));
    // Scale value only known at runtime; width is dynamic.
    out_w = -1;
  } else {
    float scale = ctx->Attrs().Get<float>("scale");
    if (scale > 0) {
      // round down
      out_w = (data_layout == DataLayout::kNCHW
                   ? static_cast<int>(dim_x[2] * scale)
                   : static_cast<int>(dim_x[1] * scale));
      // protect when input shape is -1
      out_w = out_w > 0 ? out_w : -1;
    } else {
      out_w = ctx->Attrs().Get<int>("out_w");
    }
  }

  if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
    auto out_size_dim = ctx->GetInputDim("OutSize");
    PADDLE_ENFORCE_EQ(
        out_size_dim.size(),
        1,
        platform::errors::InvalidArgument(
            "OutSize's dimension size must be 1, but got dimension = %d .",
            out_size_dim.size()));
    PADDLE_ENFORCE_EQ(
        out_size_dim[0],
        1,
        platform::errors::InvalidArgument(
            "OutSize's 0-th dimension's value must be 1, but got value = %d .",
            out_size_dim[0]));
    // Shape comes from a runtime tensor, so only LoD can be propagated here.
    ctx->ShareLoD("X", "Out");
    return;
  }

  framework::DDim dim_out;
  if (data_layout == DataLayout::kNCHW) {
    dim_out = {dim_x[0], dim_x[1], out_w};
  } else {
    dim_out = {dim_x[0], out_w, dim_x[2]};
  }
  ctx->SetOutputDim("Out", dim_out);
}

K
Kaipeng Deng 已提交
115 116 117 118
// Shape inference for 2-D interpolation ("bilinear", "nearest" or "bicubic"),
// where Input(X) is a 4-D tensor, NCHW or NHWC depending on data_layout.
// Priority of size sources: Input(SizeTensor) > Input(OutSize) >
// Input(Scale) > attr(scale) > attr(out_h)/attr(out_w).
static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) {
  auto dim_x = ctx->GetInputDim("X");
  auto interp_method = ctx->Attrs().Get<std::string>("interp_method");

  PADDLE_ENFORCE_EQ("bilinear" == interp_method || "nearest" == interp_method ||
                        "bicubic" == interp_method,
                    true,
                    platform::errors::InvalidArgument(
                        "Interpolation method can only be \"bilinear\" "
                        "or \"nearest\" or \"bicubic\" when "
                        "Input(X) dimension is 4, but got method is %s.",
                        interp_method));
  const DataLayout data_layout = framework::StringToDataLayout(
      ctx->Attrs().Get<std::string>("data_layout"));

  if (ctx->HasInputs("SizeTensor")) {
    // top priority size: actual sizes are read at runtime, the attrs only
    // provide the static shape
    auto inputs_name = ctx->Inputs("SizeTensor");
    PADDLE_ENFORCE_EQ(
        inputs_name.size(),
        2,
        platform::errors::InvalidArgument(
            "Input(SizeTensor)'s size of Op(interpolate) must be 2. "
            "Attr(out_shape)'s length must be 2 for 4-D input "
            "tensor, but got size = %d .",
            inputs_name.size()));
    int out_h = ctx->Attrs().Get<int>("out_h");
    int out_w = ctx->Attrs().Get<int>("out_w");
    framework::DDim dim_out;
    if (data_layout == DataLayout::kNCHW) {
      dim_out = {dim_x[0], dim_x[1], out_h, out_w};
    } else {
      dim_out = {dim_x[0], out_h, out_w, dim_x[3]};
    }
    ctx->SetOutputDim("Out", dim_out);

    return;
  }

  int out_h, out_w;
  if (ctx->HasInput("Scale")) {
    auto scale_tensor = ctx->GetInputDim("Scale");
    PADDLE_ENFORCE_EQ(
        scale_tensor.size(),
        1,
        platform::errors::InvalidArgument(
            "Scale's dimension size must be 1, but got dimension = %d .",
            scale_tensor.size()));
    // Scale value only known at runtime; sizes are dynamic.
    out_h = -1;
    out_w = -1;
  } else {
    float scale = ctx->Attrs().Get<float>("scale");
    if (scale > 0) {
      // round down
      out_h = (data_layout == DataLayout::kNCHW
                   ? static_cast<int>(dim_x[2] * scale)
                   : static_cast<int>(dim_x[1] * scale));
      out_w = (data_layout == DataLayout::kNCHW
                   ? static_cast<int>(dim_x[3] * scale)
                   : static_cast<int>(dim_x[2] * scale));
      // protect when input shape is -1
      out_h = out_h > 0 ? out_h : -1;
      out_w = out_w > 0 ? out_w : -1;
    } else {
      out_h = ctx->Attrs().Get<int>("out_h");
      out_w = ctx->Attrs().Get<int>("out_w");
    }
  }

  if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
    auto out_size_dim = ctx->GetInputDim("OutSize");
    PADDLE_ENFORCE_EQ(
        out_size_dim.size(),
        1,
        platform::errors::InvalidArgument("OutSize's dimension size must be 1, "
                                          "but got dimension size is %d .",
                                          out_size_dim.size()));
    PADDLE_ENFORCE_EQ(
        out_size_dim[0],
        2,
        platform::errors::InvalidArgument(
            "OutSize's dimension[0] must be 2, but got dimension[0] is %d .",
            out_size_dim[0]));
    // Shape comes from a runtime tensor, so only LoD can be propagated here.
    ctx->ShareLoD("X", "Out");
    return;
  }

  framework::DDim dim_out;
  if (data_layout == DataLayout::kNCHW) {
    dim_out = {dim_x[0], dim_x[1], out_h, out_w};
  } else {
    dim_out = {dim_x[0], out_h, out_w, dim_x[3]};
  }
  ctx->SetOutputDim("Out", dim_out);
}

// Shape inference for 3-D ("trilinear") interpolation, where Input(X) is a
// 5-D tensor, NCDHW or NDHWC depending on data_layout.
// Priority of size sources: Input(SizeTensor) > Input(OutSize) >
// Input(Scale) > attr(scale) > attr(out_d)/attr(out_h)/attr(out_w).
static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) {
  auto dim_x = ctx->GetInputDim("X");
  auto interp_method = ctx->Attrs().Get<std::string>("interp_method");

  PADDLE_ENFORCE_EQ(
      "trilinear",
      interp_method,
      platform::errors::InvalidArgument(
          "Interpolation method can only be \"trilinear\" when Input(X) "
          "dimension is 5, but got method = %s .",
          interp_method));
  const DataLayout data_layout = framework::StringToDataLayout(
      ctx->Attrs().Get<std::string>("data_layout"));
  const bool channel_first = (data_layout == DataLayout::kNCHW);

  // Highest priority: per-dimension runtime size tensors. The out_* attrs
  // only supply the static shape here.
  if (ctx->HasInputs("SizeTensor")) {
    auto size_tensor_names = ctx->Inputs("SizeTensor");
    PADDLE_ENFORCE_EQ(
        size_tensor_names.size(),
        3,
        platform::errors::InvalidArgument(
            "Input(SizeTensor)'s size of Op(interpolate) must be 3. "
            "Attr(out_shape)'s length must be 3 for 5-D input "
            "tensor, but got size = %d .",
            size_tensor_names.size()));
    const int attr_d = ctx->Attrs().Get<int>("out_d");
    const int attr_h = ctx->Attrs().Get<int>("out_h");
    const int attr_w = ctx->Attrs().Get<int>("out_w");
    framework::DDim dim_out;
    if (channel_first) {
      dim_out = {dim_x[0], dim_x[1], attr_d, attr_h, attr_w};
    } else {
      dim_out = {dim_x[0], attr_d, attr_h, attr_w, dim_x[4]};
    }
    ctx->SetOutputDim("Out", dim_out);
    return;
  }

  int out_d, out_h, out_w;
  if (ctx->HasInput("Scale")) {
    auto scale_tensor = ctx->GetInputDim("Scale");
    PADDLE_ENFORCE_EQ(
        scale_tensor.size(),
        1,
        platform::errors::InvalidArgument(
            "Scale's dimension size must be 1, but got size = %d .",
            scale_tensor.size()));
    // Scale is a runtime tensor, so the spatial sizes stay dynamic.
    out_d = -1;
    out_h = -1;
    out_w = -1;
  } else {
    float scale = ctx->Attrs().Get<float>("scale");
    if (scale > 0) {
      // Round down when applying the static scale factor.
      if (channel_first) {
        out_d = static_cast<int>(dim_x[2] * scale);
        out_h = static_cast<int>(dim_x[3] * scale);
        out_w = static_cast<int>(dim_x[4] * scale);
      } else {
        out_d = static_cast<int>(dim_x[1] * scale);
        out_h = static_cast<int>(dim_x[2] * scale);
        out_w = static_cast<int>(dim_x[3] * scale);
      }
      // Keep unknown (-1) input extents unknown in the output.
      if (out_d <= 0) out_d = -1;
      if (out_h <= 0) out_h = -1;
      if (out_w <= 0) out_w = -1;
    } else {
      out_d = ctx->Attrs().Get<int>("out_d");
      out_h = ctx->Attrs().Get<int>("out_h");
      out_w = ctx->Attrs().Get<int>("out_w");
    }
  }

  if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
    auto out_size_dim = ctx->GetInputDim("OutSize");
    PADDLE_ENFORCE_EQ(
        out_size_dim.size(),
        1,
        platform::errors::InvalidArgument(
            "OutSize's dimension size must be 1, but got size is %d.",
            out_size_dim.size()));
    PADDLE_ENFORCE_EQ(out_size_dim[0],
                      3,
                      platform::errors::InvalidArgument(
                          "OutSize's dim[0] must be 3, but got size is %d.",
                          out_size_dim[0]));
    // Shape comes from a runtime tensor; only LoD is propagated here.
    ctx->ShareLoD("X", "Out");
    return;
  }

  framework::DDim dim_out;
  if (channel_first) {
    dim_out = {dim_x[0], dim_x[1], out_d, out_h, out_w};
  } else {
    dim_out = {dim_x[0], out_d, out_h, out_w, dim_x[4]};
  }
  ctx->SetOutputDim("Out", dim_out);
}

312
class InterpolateOp : public framework::OperatorWithKernel {
313 314 315 316 317
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
318 319
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Interpolate");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Interpolate");
320

321
    auto dim_x = ctx->GetInputDim("X");  // NCHW format
322 323 324 325 326 327 328 329 330
    PADDLE_ENFORCE(
        dim_x.size() == 3 || dim_x.size() == 4 || dim_x.size() == 5,
        platform::errors::Unimplemented(
            "Input(X) dimension must be 3, 4 or 5, but got dimension = %d .",
            dim_x.size()));
    if (dim_x.size() == 3) {
      // shape check for 1D interpolate for input tensor shape NCHW
      Interpolate1DInferShapeCheck(ctx);
    } else if (dim_x.size() == 4) {
K
Kaipeng Deng 已提交
331 332 333 334 335
      // shape check for 2D interpolate for input tensor shape NCHW
      Interpolate2DInferShapeCheck(ctx);
    } else {  // dim_x.size() == 5
      // shape check for 3D interpolate for input tensor shape NCDHW
      Interpolate3DInferShapeCheck(ctx);
336 337 338 339 340 341
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
342
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
J
jiahongyu 已提交
343
    return framework::OpKernelType(data_type, ctx.GetPlace());
344
  }
345 346

  framework::OpKernelType GetKernelTypeForVar(
347
      const std::string& var_name,
348
      const phi::DenseTensor& tensor,
349
      const framework::OpKernelType& expected_kernel_type) const override {
350 351 352 353 354 355 356 357 358 359
#ifdef PADDLE_WITH_MKLDNN
    if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
        (tensor.layout() != framework::DataLayout::kMKLDNN)) {
      auto attrs = Attrs();
      auto ar = paddle::framework::AttrReader(attrs);
      const std::string data_format = ar.Get<std::string>("data_layout");
      auto dl = framework::StringToDataLayout(data_format);
      // Some models may have intentionally set "AnyLayout" for pool
      // op. Treat this as NCHW (default data_format value)
      if (dl != framework::DataLayout::kAnyLayout) {
360 361
        return framework::OpKernelType(
            expected_kernel_type.data_type_, tensor.place(), dl);
362 363 364
      }
    }
#endif
365 366 367
    if (var_name == "SizeTensor" || var_name == "Scale") {
      return expected_kernel_type;
    }
368 369
    return framework::OpKernelType(
        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
370
  }
371 372
};

373
// Proto maker shared by all *_interp operators: declares Input(X), the
// optional size sources (OutSize, SizeTensor, Scale), Output(Out), the
// interpolation attributes, and the operator documentation string.
class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "The input tensor of interpolate operator, "
             "This is a 4-D tensor with shape of [N, C, H, W] or a "
             "5-D tensor with shape of [N, C, D, H, W].");
    AddInput("OutSize",
             "This is a 1-D tensor with two numbers to specify output size. "
             "It should be [output_height, output_width] when input is a 4-D "
             "tensor and should be [output_depth, output_height, output_width] "
             "when input is a 5-D tensor. It has a higher priority than "
             "the attr(out_d), attr(out_h), attr(out_w) and attr(scale).")
        .AsDispensable();
    AddInput("SizeTensor",
             "(vector<Tensor<int32>>, optional). If provided, interpolate will "
             "use this. The shape of the tensor in vector MUST BE [1]. "
             "It has the highest priority compare with Input(OutSize) and "
             "attr(out_d), attr(out_h), attr(out_w) and attr(scale).")
        .AsDuplicable()
        .AsDispensable();
    AddInput("Scale",
             "This is a 1-D tensor with one number to specify output scale. "
             "It has the higher priority compare with attr(scale).")
        .AsDispensable();
    AddOutput("Out",
              "The output tensor of interpolate operator, "
              "This is a tensor in same rank with Input(X).");

    AddAttr<std::string>(
        "data_layout",
        "(string, default NCHW) Only used in "
        "an optional string from: \"NHWC\", \"NCHW\". "
        "Specify that the data format of the input and output data is "
        "channel_first or channel_last.")
        .SetDefault("NCHW");
    // Default 0 means "not set"; the actual size then comes from scale /
    // OutSize / SizeTensor (see the InferShapeCheck helpers).
    AddAttr<int>("out_d", "output depth of interpolate op.").SetDefault(0);
    AddAttr<int>("out_h", "output height of interpolate op.").SetDefault(0);
    AddAttr<int>("out_w", "output width of interpolate op.").SetDefault(0);
    AddAttr<float>("scale", "scale factor of interpolate op.").SetDefault(0.);
    AddAttr<std::string>("interp_method",
                         "(string, default \"bilinear\"), interpolation "
                         "method, can be \"linear\" for linear interpolation"
                         ",\"bilinear\" for "
                         "bilinear interpolation, \"trilinear\" for trilinear "
                         "interpolation and \"nearest\" for nearest "
                         "neighbor interpolation, and \"bicubic\" for bicubic"
                         "interpolation.")
        .SetDefault("bilinear");
    AddAttr<bool>(
        "align_corners",
        "an optional bool. Defaults to True. "
        "If True, the centers of 4 corner pixels of the input and output "
        "tensors are aligned, preserving the values at the corner pixels, "
        "If False, are not aligned")
        .SetDefault(true);
    AddAttr<int>("align_mode",
                 "(int, default \'1\'), optional for bilinear interpolation, "
                 "can be \'0\' for src_idx = scale*(dst_indx+0.5)-0.5 , "
                 "can be \'1\' for src_idx = scale*dst_index .")
        .SetDefault(1);
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false)
        .AsExtra();
    AddComment(R"DOC(
          This operator samples input X to given output shape by using specified
          interpolation method, the interpolation methods can be \"nearest\"
          for nearest neighbor interpolation and \"bilinear\" for bilinear
          interpolation and \"linear\" for linear interpolation..

          Nearest neighbor interpolation is to perform nearest neighbor interpolation
          in both the 3rd dimension(in height direction) and the 4th dimension(in width
          direction) on input tensor.

          Linear interpolation is the method of using a line connecting two known quantities
          to determine the value of an unknown quantity between the two known quantities.

          Bilinear interpolation is an extension of linear interpolation for
          interpolating functions of two variables (e.g. H-direction and
          W-direction in this op) on a rectilinear 2D grid. The key idea is
          to perform linear interpolation first in one direction, and then
          again in the other direction.

          Trilinear interpolation is an extension of linear interpolation for
          interpolating functions of three variables (e.g. D-direction,
          H-direction and W-direction in this op) on a rectilinear 3D grid.
          The linear interpolation is performed on three directions.

          Bicubic interpolation is an extension of cubic interpolation for interpolating
          data points on a two-dimensional regular grid. The interpolated surface is
          smoother than corresponding surfaces obtained by bilinear interpolation or
          nearest-neighbor interpolation.

          Align_corners and align_mode are optional parameters,the calculation method
          of interpolation can be selected by them.

          Example:

          For scale:

            if align_corners = True and out_{size}>1 :

              scale_{factor} = (in_{size}-1.0)/(out_{size}-1.0)

            else:

              scale_{factor} = float(in_{size}/out_{size})


          Nearest neighbor interpolation:

          if:
              align_corners = False

              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:

              H_out = \left \lfloor {H_{in} * scale_{}factor}} \right \rfloor
              W_out = \left \lfloor {W_{in} * scale_{}factor}} \right \rfloor

          else:
              align_corners = True

              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:

              H_out = round(H_{in} * scale_{factor})
              W_out = round(W_{in} * scale_{factor})

          Bilinear interpolation:

          if:
              align_corners = False , align_mode = 0

              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:

              H_out = (H_{in}+0.5) * scale_{factor} - 0.5
              W_out = (W_{in}+0.5) * scale_{factor} - 0.5


          else:

              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:

              H_out = H_{in} * scale_{factor}
              W_out = W_{in} * scale_{factor}

          Trilinear interpolation:

          if:
              align_corners = False , align_mode = 0

              input : (N,C,D_in,H_in,W_in)
              output: (N,C,D_out,H_out,W_out) where:

              D_out = (D_{in}+0.5) * scale_{factor} - 0.5
              H_out = (H_{in}+0.5) * scale_{factor} - 0.5
              W_out = (W_{in}+0.5) * scale_{factor} - 0.5


          else:

              input : (N,C,D_in,H_in,W_in)
              output: (N,C,D_out,H_out,W_out) where:

              D_out = D_{in} * scale_{factor}
              H_out = H_{in} * scale_{factor}
              W_out = W_{in} * scale_{factor}

          Bicubic interpolation:

          if:
              align_corners = False
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              H_out = (H_{in}+0.5) * scale_{factor} - 0.5
              W_out = (W_{in}+0.5) * scale_{factor} - 0.5
          else:
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              H_out = H_{in} * scale_{factor}
              W_out = W_{in} * scale_{factor}

          For details of nearest neighbor interpolation, please refer to Wikipedia:
          https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation

          For details of bilinear interpolation, please refer to Wikipedia:
          https://en.wikipedia.org/wiki/Bilinear_interpolation

          For details of trilinear interpolation, please refer to Wikipedia:
          https://en.wikipedia.org/wiki/Trilinear_interpolation

          For details of bicubic interpolation, please refer to Wikipedia:
          https://en.wikipedia.org/wiki/Bicubic_interpolation
         )DOC");
  }
};

574
class InterpolateOpGrad : public framework::OperatorWithKernel {
575 576 577 578 579
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
580
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "InterpolateGrad");
581 582 583 584
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
                   "Input",
                   "Out@GRAD",
                   "InterpolateGrad");
585

586 587 588 589 590 591 592 593
    auto dim_x = ctx->GetInputDim("X");
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
594 595 596
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.GetPlace());
597
  }
598 599

  framework::OpKernelType GetKernelTypeForVar(
600
      const std::string& var_name,
601
      const phi::DenseTensor& tensor,
602 603 604 605
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "SizeTensor" || var_name == "Scale") {
      return expected_kernel_type;
    }
606 607
    return framework::OpKernelType(
        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
608
  }
609 610
};

H
hong 已提交
611 612
// Builds the *_interp_grad op description from the forward op, forwarding the
// optional size/scale inputs when the forward op used them.
template <typename T>
class InterpolateGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType(this->ForwardOpType() + "_grad");
    op->SetInput("X", this->Input("X"));
    // Each of these inputs is dispensable; forward it only when present.
    for (const char* name : {"SizeTensor", "OutSize", "Scale"}) {
      if (this->HasInput(name)) {
        op->SetInput(name, this->Input(name));
      }
    }
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
  }
};

635
// The grad kernel never reads the contents of X (only its shape), so its
// buffer need not be kept alive for the backward pass.
DECLARE_NO_NEED_BUFFER_VARS_INFERER(InterpolateGradNoNeedBufferVarsInferer,
                                    "X");

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Each interpolation flavor (bilinear/nearest/trilinear/linear/bicubic) is
// registered as a separate operator that shares the same Op / Maker /
// GradMaker implementations above.
REGISTER_OPERATOR(bilinear_interp,
                  ops::InterpolateOp,
                  ops::InterpolateOpMaker,
                  ops::InterpolateGradMaker<paddle::framework::OpDesc>,
                  ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(bilinear_interp_grad,
                  ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(nearest_interp,
                  ops::InterpolateOp,
                  ops::InterpolateOpMaker,
                  ops::InterpolateGradMaker<paddle::framework::OpDesc>,
                  ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(nearest_interp_grad,
                  ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(trilinear_interp,
                  ops::InterpolateOp,
                  ops::InterpolateOpMaker,
                  ops::InterpolateGradMaker<paddle::framework::OpDesc>,
                  ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(trilinear_interp_grad,
                  ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(bicubic_interp,
                  ops::InterpolateOp,
                  ops::InterpolateOpMaker,
                  ops::InterpolateGradMaker<paddle::framework::OpDesc>,
                  ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(bicubic_interp_grad,
                  ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInferer);
// CPU kernels: forward kernels are registered for float/double/uint8 except
// bicubic (float/double only, per the registrations below); grad kernels are
// float/double.
REGISTER_OP_CPU_KERNEL(bilinear_interp,
                       ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>,
                       ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(bilinear_interp_grad,
                       ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);
REGISTER_OP_CPU_KERNEL(nearest_interp,
                       ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>,
                       ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(nearest_interp_grad,
                       ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);
REGISTER_OP_CPU_KERNEL(trilinear_interp,
                       ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>,
                       ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(trilinear_interp_grad,
                       ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);
REGISTER_OPERATOR(linear_interp,
                  ops::InterpolateOp,
                  ops::InterpolateOpMaker,
                  ops::InterpolateGradMaker<paddle::framework::OpDesc>,
                  ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(linear_interp_grad,
                  ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(linear_interp,
                       ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>,
                       ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(linear_interp_grad,
                       ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);
REGISTER_OP_CPU_KERNEL(bicubic_interp,
                       ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>);
REGISTER_OP_CPU_KERNEL(bicubic_interp_grad,
                       ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);