/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/fluid/operators/interpolate_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

namespace paddle {
namespace operators {

using framework::Tensor;
using DataLayout = framework::DataLayout;

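// Compile-time shape inference helpers for 1-D/2-D/3-D interpolation.
// The output spatial sizes come, in priority order, from Input(SizeTensor),
// Input(Scale) / attr(scale), or attr(out_d)/attr(out_h)/attr(out_w);
// Input(OutSize) is only validated here and is resolved at runtime.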
static void Interpolate1DInferShapeCheck(framework::InferShapeContext* ctx) {
  auto dim_x = ctx->GetInputDim("X");
  auto interp_method = ctx->Attrs().Get<std::string>("interp_method");

  PADDLE_ENFORCE_EQ("linear", interp_method,
                    platform::errors::InvalidArgument(
                        "Interpolation method can only be \"linear\" when"
                        "Input(X) dimension is 3, but got method = %s .",
                        interp_method));
  const DataLayout data_layout = framework::StringToDataLayout(
      ctx->Attrs().Get<std::string>("data_layout"));

  if (ctx->HasInputs("SizeTensor")) {
    // top priority size
    auto inputs_name = ctx->Inputs("SizeTensor");
    PADDLE_ENFORCE_EQ(
        inputs_name.size(), 1,
        platform::errors::InvalidArgument(
            "Input(SizeTensor)'size of Op(interpolate) must be 1. "
            "Attr(out_shape)'s length must be 1 for 3-D input tensor, but got "
            "size = %d .",
            inputs_name.size()));
    int out_w = ctx->Attrs().Get<int>("out_w");
    framework::DDim dim_out;
    if (data_layout == DataLayout::kNCHW) {
      dim_out = {dim_x[0], dim_x[1], out_w};
    } else {
      dim_out = {dim_x[0], out_w, dim_x[2]};
    }
    ctx->SetOutputDim("Out", dim_out);

    return;
  }

  int out_w;
  if (ctx->HasInput("Scale")) {
    auto scale_tensor = ctx->GetInputDim("Scale");
    PADDLE_ENFORCE_EQ(
        scale_tensor.size(), 1,
        platform::errors::InvalidArgument(
            "Scale's dimension size must be 1, but got dimension = %d .",
            scale_tensor.size()));
    out_w = -1;
  } else {
    float scale = ctx->Attrs().Get<float>("scale");
    if (scale > 0) {
      // round down
      out_w = (data_layout == DataLayout::kNCHW
                   ? static_cast<int>(dim_x[2] * scale)
                   : static_cast<int>(dim_x[1] * scale));
      // protect when input shape is -1
      out_w = out_w > 0 ? out_w : -1;
    } else {
      out_w = ctx->Attrs().Get<int>("out_w");
    }
  }

  if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
    auto out_size_dim = ctx->GetInputDim("OutSize");
    PADDLE_ENFORCE_EQ(
        out_size_dim.size(), 1,
        platform::errors::InvalidArgument(
            "OutSize's dimension size must be 1, but got dimention = %d .",
            out_size_dim.size()));
    PADDLE_ENFORCE_EQ(out_size_dim[0], 1, platform::errors::InvalidArgument(
                                              "OutSize's dim[0] must be 1"));
    ctx->ShareLoD("X", "Out");
    return;
  }

  framework::DDim dim_out;
  if (data_layout == DataLayout::kNCHW) {
    dim_out = {dim_x[0], dim_x[1], out_w};
  } else {
    dim_out = {dim_x[0], out_w, dim_x[2]};
  }
  ctx->SetOutputDim("Out", dim_out);
}

static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) {
  auto dim_x = ctx->GetInputDim("X");
  auto interp_method = ctx->Attrs().Get<std::string>("interp_method");

  PADDLE_ENFORCE_EQ("bilinear" == interp_method || "nearest" == interp_method ||
                        "bicubic" == interp_method,
                    true, platform::errors::InvalidArgument(
                              "Interpolation method can only be \"bilinear\" "
                              "or \"nearest\" or \"bicubic\" when "
                              "Input(X) dimension is 4, but got method is %s.",
                              interp_method));
  const DataLayout data_layout = framework::StringToDataLayout(
      ctx->Attrs().Get<std::string>("data_layout"));

  if (ctx->HasInputs("SizeTensor")) {
    // top priority size
    auto inputs_name = ctx->Inputs("SizeTensor");
    PADDLE_ENFORCE_EQ(
        inputs_name.size(), 2,
        platform::errors::InvalidArgument(
            "Input(SizeTensor)'size of Op(interpolate) must be 2. "
            "Attr(out_shape)'s length must be 2 for 4-D input "
            "tensor, but got size = %d .",
            inputs_name.size()));
    int out_h = ctx->Attrs().Get<int>("out_h");
    int out_w = ctx->Attrs().Get<int>("out_w");
    framework::DDim dim_out;
    if (data_layout == DataLayout::kNCHW) {
      dim_out = {dim_x[0], dim_x[1], out_h, out_w};
    } else {
      dim_out = {dim_x[0], out_h, out_w, dim_x[3]};
    }
    ctx->SetOutputDim("Out", dim_out);

    return;
  }

  int out_h, out_w;
  if (ctx->HasInput("Scale")) {
    auto scale_tensor = ctx->GetInputDim("Scale");
    PADDLE_ENFORCE_EQ(
        scale_tensor.size(), 1,
        platform::errors::InvalidArgument(
            "Scale's dimension size must be 1, but got dimension = %d .",
            scale_tensor.size()));
    out_h = -1;
    out_w = -1;
  } else {
    float scale = ctx->Attrs().Get<float>("scale");
    if (scale > 0) {
      // round down
      out_h = (data_layout == DataLayout::kNCHW
                   ? static_cast<int>(dim_x[2] * scale)
                   : static_cast<int>(dim_x[1] * scale));
      out_w = (data_layout == DataLayout::kNCHW
                   ? static_cast<int>(dim_x[3] * scale)
                   : static_cast<int>(dim_x[2] * scale));
      // protect when input shape is -1
      out_h = out_h > 0 ? out_h : -1;
      out_w = out_w > 0 ? out_w : -1;
    } else {
      out_h = ctx->Attrs().Get<int>("out_h");
      out_w = ctx->Attrs().Get<int>("out_w");
    }
  }

  if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
    auto out_size_dim = ctx->GetInputDim("OutSize");
    PADDLE_ENFORCE_EQ(
        out_size_dim.size(), 1,
        platform::errors::InvalidArgument("OutSize's dimension size must be 1, "
                                          "but got dimension size is %d .",
                                          out_size_dim.size()));
    PADDLE_ENFORCE_EQ(
        out_size_dim[0], 2,
        platform::errors::InvalidArgument(
            "OutSize's dimension[0] must be 2, but got dimension[0] is %d .",
            out_size_dim[0]));
    ctx->ShareLoD("X", "Out");
    return;
  }

  framework::DDim dim_out;
  if (data_layout == DataLayout::kNCHW) {
    dim_out = {dim_x[0], dim_x[1], out_h, out_w};
  } else {
    dim_out = {dim_x[0], out_h, out_w, dim_x[3]};
  }
  ctx->SetOutputDim("Out", dim_out);
}

static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) {
  auto dim_x = ctx->GetInputDim("X");
  auto interp_method = ctx->Attrs().Get<std::string>("interp_method");

  PADDLE_ENFORCE_EQ(
      "trilinear", interp_method,
      platform::errors::InvalidArgument(
          "Interpolation method can only be \"trilinear\" when Input(X) "
          "dimension is 5, but got method = %s .",
          interp_method));
  const DataLayout data_layout = framework::StringToDataLayout(
      ctx->Attrs().Get<std::string>("data_layout"));

  if (ctx->HasInputs("SizeTensor")) {
    // top priority size
    auto inputs_name = ctx->Inputs("SizeTensor");
    PADDLE_ENFORCE_EQ(
        inputs_name.size(), 3,
        platform::errors::InvalidArgument(
            "Input(SizeTensor)'s size of Op(interpolate) must be 3. "
            "Attr(out_shape)'s length must be 3 for 5-D input "
            "tensor, but got size = %d .",
            inputs_name.size()));
    int out_d = ctx->Attrs().Get<int>("out_d");
    int out_h = ctx->Attrs().Get<int>("out_h");
    int out_w = ctx->Attrs().Get<int>("out_w");
    framework::DDim dim_out;
    if (data_layout == DataLayout::kNCHW) {
      dim_out = {dim_x[0], dim_x[1], out_d, out_h, out_w};
    } else {
      dim_out = {dim_x[0], out_d, out_h, out_w, dim_x[4]};
    }
    ctx->SetOutputDim("Out", dim_out);

    return;
  }

  int out_d, out_h, out_w;
  if (ctx->HasInput("Scale")) {
    auto scale_tensor = ctx->GetInputDim("Scale");
    PADDLE_ENFORCE_EQ(
        scale_tensor.size(), 1,
        platform::errors::InvalidArgument(
            "Scale's dimension size must be 1, but got size = %d .",
            scale_tensor.size()));
    out_d = -1;
    out_h = -1;
    out_w = -1;
  } else {
    float scale = ctx->Attrs().Get<float>("scale");
    if (scale > 0) {
      // round down
      out_d = (data_layout == DataLayout::kNCHW
                   ? static_cast<int>(dim_x[2] * scale)
                   : static_cast<int>(dim_x[1] * scale));
      out_h = (data_layout == DataLayout::kNCHW
                   ? static_cast<int>(dim_x[3] * scale)
                   : static_cast<int>(dim_x[2] * scale));
      out_w = (data_layout == DataLayout::kNCHW
                   ? static_cast<int>(dim_x[4] * scale)
                   : static_cast<int>(dim_x[3] * scale));
      // protect when input shape is -1
      out_d = out_d > 0 ? out_d : -1;
      out_h = out_h > 0 ? out_h : -1;
      out_w = out_w > 0 ? out_w : -1;
    } else {
      out_d = ctx->Attrs().Get<int>("out_d");
      out_h = ctx->Attrs().Get<int>("out_h");
      out_w = ctx->Attrs().Get<int>("out_w");
    }
  }

  if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
    auto out_size_dim = ctx->GetInputDim("OutSize");
    PADDLE_ENFORCE_EQ(
        out_size_dim.size(), 1,
        platform::errors::InvalidArgument(
            "OutSize's dimension size must be 1, but got size is %d.",
            out_size_dim.size()));
    PADDLE_ENFORCE_EQ(out_size_dim[0], 3,
                      platform::errors::InvalidArgument(
                          "OutSize's dim[0] must be 3, but got size is %d.",
                          out_size_dim[0]));
    ctx->ShareLoD("X", "Out");
    return;
  }

  framework::DDim dim_out;
  if (data_layout == DataLayout::kNCHW) {
    dim_out = {dim_x[0], dim_x[1], out_d, out_h, out_w};
  } else {
    dim_out = {dim_x[0], out_d, out_h, out_w, dim_x[4]};
  }
  ctx->SetOutputDim("Out", dim_out);
}

class InterpolateOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Interpolate");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Interpolate");

    auto dim_x = ctx->GetInputDim("X");  // NCHW format
    PADDLE_ENFORCE(
        dim_x.size() == 3 || dim_x.size() == 4 || dim_x.size() == 5,
        platform::errors::Unimplemented(
            "Input(X) dimension must be 3, 4 or 5, but got dimension = %d .",
            dim_x.size()));
    if (dim_x.size() == 3) {
      // shape check for 1D interpolate for input tensor shape NCHW
      Interpolate1DInferShapeCheck(ctx);
    } else if (dim_x.size() == 4) {
      // shape check for 2D interpolate for input tensor shape NCHW
      Interpolate2DInferShapeCheck(ctx);
    } else {  // dim_x.size() == 5
      // shape check for 3D interpolate for input tensor shape NCDHW
      Interpolate3DInferShapeCheck(ctx);
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
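    // Default to the plain kernel; switch to the MKLDNN kernel only when it
    // can be used and the interpolation method is one it supports.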
    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
    framework::LibraryType library = framework::LibraryType::kPlain;

#ifdef PADDLE_WITH_MKLDNN
    auto interp_method = ctx.Attr<std::string>("interp_method");
    // TODO(danqing): support other interp_method
    if (this->CanMKLDNNBeUsed(ctx) &&
        (interp_method == "nearest" || interp_method == "bilinear")) {
      layout = framework::DataLayout::kMKLDNN;
      library = framework::LibraryType::kMKLDNN;
    }
#endif

    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
        layout, library);
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
#ifdef PADDLE_WITH_MKLDNN
    if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
        (tensor.layout() != framework::DataLayout::kMKLDNN)) {
      auto attrs = Attrs();
      auto ar = paddle::framework::AttrReader(attrs);
      const std::string data_format = ar.Get<std::string>("data_layout");
      auto dl = framework::StringToDataLayout(data_format);
      // Some models may have intentionally set "AnyLayout" for this
      // op. Treat this as NCHW (default data_format value)
      if (dl != framework::DataLayout::kAnyLayout) {
        return framework::OpKernelType(expected_kernel_type.data_type_,
                                       tensor.place(), dl);
      }
    }
#endif
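    // SizeTensor and Scale only carry size/scale values, so they are used
    // as-is without transforming their layout or place.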
    if (var_name == "SizeTensor" || var_name == "Scale") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};

class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "The input tensor of interpolate operator, "
             "This is a 4-D tensor with shape of [N, C, H, W] or a "
             "5-D tensor with shape of [N, C, D, H, W].");
    AddInput("OutSize",
             "This is a 1-D tensor with two numbers to specify output size. "
K
Kaipeng Deng 已提交
376 377
             "It should be [output_height, output_width] when input is a 4-D "
             "tensor and should be [output_depth, output_height, output_width] "
378 379 380 381 382 383 384 385 386 387 388 389 390
             "when input is a 5-D tensor. It has a higher priority than "
             "the attr(out_d), attr(out_h), attr(out_w) and attr(scale).")
        .AsDispensable();
    AddInput("SizeTensor",
             "(vector<Tensor<int32>>, optional). If provided, interpolate will "
             "use this. The shape of the tensor in vector MUST BE [1]. "
             "It has the highest priority compare with Input(OutSize) and "
             "attr(out_d), attr(out_h), attr(out_w) and attr(scale).")
        .AsDuplicable()
        .AsDispensable();
    AddInput("Scale",
             "This is a 1-D tensor with one number to specify output scale. "
             "It has the higher priority compare with attr(scale).")
        .AsDispensable();
    AddOutput("Out",
              "The output tensor of interpolate operator, "
              "This is a tensor in same rank with Input(X).");

    AddAttr<std::string>(
        "data_layout",
        "(string, default NCHW) Only used in "
        "an optional string from: \"NHWC\", \"NCHW\". "
        "Specify that the data format of the input and output data is "
        "channel_first or channel_last.")
        .SetDefault("NCHW");
    AddAttr<int>("out_d", "output depth of interpolate op.").SetDefault(0);
    AddAttr<int>("out_h", "output height of interpolate op.").SetDefault(0);
    AddAttr<int>("out_w", "output width of interpolate op.").SetDefault(0);
    AddAttr<float>("scale", "scale factor of interpolate op.").SetDefault(0.);
    AddAttr<std::string>("interp_method",
                         "(string, default \"bilinear\"), interpolation "
409 410
                         "method, can be \"linear\" for linear interpolation"
                         ",\"bilinear\" for "
K
Kaipeng Deng 已提交
411 412
                         "bilinear interpolation, \"trilinear\" for trilinear "
                         "interpolation and \"nearest\" for nearest "
X
xiaoting 已提交
413 414
                         "neighbor interpolation, and \"bicubic\" for bicubic"
                         "interpolation.")
        .SetDefault("bilinear");
    AddAttr<bool>(
        "align_corners",
        "an optional bool. Defaults to True. "
419 420
        "If True, the centers of 4 corner pixels of the input and output "
        "tensors are aligned, preserving the values at the corner pixels, "
T
Tink_Y 已提交
421
        "If False, are not aligned")
        .SetDefault(true);
    AddAttr<int>("align_mode",
                 "(int, default \'1\'), optional for bilinear interpolation, "
                 "can be \'0\' for src_idx = scale*(dst_indx+0.5)-0.5 , "
                 "can be \'1\' for src_idx = scale*dst_index .")
        .SetDefault(1);
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false);
    AddComment(R"DOC(
          This operator samples input X to the given output shape using the
          specified interpolation method. The interpolation method can be
          \"nearest\" for nearest neighbor interpolation, \"linear\" for linear
          interpolation, \"bilinear\" for bilinear interpolation, \"trilinear\"
          for trilinear interpolation, or \"bicubic\" for bicubic interpolation.

          Nearest neighbor interpolation performs nearest neighbor interpolation
          in both the 3rd dimension (the height direction) and the 4th dimension
          (the width direction) on the input tensor.
           
          Linear interpolation is the method of using a line connecting two known quantities 
          to determine the value of an unknown quantity between the two known quantities. 
          
          Bilinear interpolation is an extension of linear interpolation for 
          interpolating functions of two variables (e.g. H-direction and 
          W-direction in this op) on a rectilinear 2D grid. The key idea is 
          to perform linear interpolation first in one direction, and then 
          again in the other direction.

          Trilinear interpolation is an extension of linear interpolation for 
          interpolating functions of three variables (e.g. D-direction, 
          H-direction and W-direction in this op) on a rectilinear 3D grid. 
          The linear interpolation is performed on three directions.

          Bicubic interpolation is an extension of cubic interpolation for interpolating
          data points on a two-dimensional regular grid. The interpolated surface is
          smoother than corresponding surfaces obtained by bilinear interpolation or
          nearest-neighbor interpolation.

          Align_corners and align_mode are optional parameters; the calculation
          method of the interpolation can be selected through them.
          
          Example:

          For scale:
          
            if align_corners = True and out_{size}>1 :

              scale_{factor} = (in_{size}-1.0)/(out_{size}-1.0)
            
            else:
              
              scale_{factor} = float(in_{size}/out_{size})
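
            For example, with in_{size} = 4 and out_{size} = 8:
              align_corners = True  gives scale_{factor} = (4-1.0)/(8-1.0) = 3/7
              align_corners = False gives scale_{factor} = 4/8 = 0.5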
            
          
          Nearest neighbor interpolation:
          
          if:
              align_corners = False

              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:

              H_out = \left \lfloor {H_{in} * scale_{factor}} \right \rfloor
              W_out = \left \lfloor {W_{in} * scale_{factor}} \right \rfloor

          else:
              align_corners = True

              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:

              H_out = round(H_{in} * scale_{factor})
              W_out = round(W_{in} * scale_{factor})

          Bilinear interpolation:

          if:
              align_corners = False , align_mode = 0
              
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              
              H_out = (H_{in}+0.5) * scale_{factor} - 0.5
              W_out = (W_{in}+0.5) * scale_{factor} - 0.5


          else:
           
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:

              H_out = H_{in} * scale_{factor}
              W_out = W_{in} * scale_{factor}

          Trilinear interpolation:

          if:
              align_corners = False , align_mode = 0
              
              input : (N,C,D_in,H_in,W_in)
              output: (N,C,D_out,H_out,W_out) where:
              
              D_out = (D_{in}+0.5) * scale_{factor} - 0.5
              H_out = (H_{in}+0.5) * scale_{factor} - 0.5
              W_out = (W_{in}+0.5) * scale_{factor} - 0.5


          else:
           
              input : (N,C,D_in,H_in,W_in)
              output: (N,C,D_out,H_out,W_out) where:

              D_out = D_{in} * scale_{factor}
              H_out = H_{in} * scale_{factor}
              W_out = W_{in} * scale_{factor}

          Bicubic interpolation:

          if:
              align_corners = False
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              H_out = (H_{in}+0.5) * scale_{factor} - 0.5
              W_out = (W_{in}+0.5) * scale_{factor} - 0.5
          else:
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              H_out = H_{in} * scale_{factor}
              W_out = W_{in} * scale_{factor}

          For details of nearest neighbor interpolation, please refer to Wikipedia: 
          https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation

          For details of bilinear interpolation, please refer to Wikipedia: 
          https://en.wikipedia.org/wiki/Bilinear_interpolation

          For details of trilinear interpolation, please refer to Wikipedia: 
          https://en.wikipedia.org/wiki/Trilinear_interpolation

          For details of bicubic interpolation, please refer to Wikipedia:
          https://en.wikipedia.org/wiki/Bicubic_interpolation
         )DOC");
  }
};

class InterpolateOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "InterpolateGrad");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                   "Out@GRAD", "InterpolateGrad");

    auto dim_x = ctx->GetInputDim("X");
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.GetPlace());
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "SizeTensor" || var_name == "Scale") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};

template <typename T>
class InterpolateGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType(this->ForwardOpType() + "_grad");
    op->SetInput("X", this->Input("X"));
    if (this->HasInput("SizeTensor") > 0) {
      op->SetInput("SizeTensor", this->Input("SizeTensor"));
    }
    if (this->HasInput("OutSize") > 0) {
      op->SetInput("OutSize", this->Input("OutSize"));
    }
    if (this->HasInput("Scale") > 0) {
      op->SetInput("Scale", this->Input("Scale"));
    }
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
  }
};

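// The interpolate grad kernels only need the shape of X, not its data, so X
// is marked as a no-need-buffer variable for the grad op.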
DECLARE_NO_NEED_BUFFER_VARS_INFERER(InterpolateGradNoNeedBufferVarsInferer,
                                    "X");

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(bilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
                  ops::InterpolateGradMaker<paddle::framework::OpDesc>,
                  ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(bilinear_interp_grad, ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(nearest_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
                  ops::InterpolateGradMaker<paddle::framework::OpDesc>,
                  ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(nearest_interp_grad, ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(trilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
                  ops::InterpolateGradMaker<paddle::framework::OpDesc>,
                  ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(trilinear_interp_grad, ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(bicubic_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
                  ops::InterpolateGradMaker<paddle::framework::OpDesc>,
                  ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(bicubic_interp_grad, ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>,
                       ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(bilinear_interp_grad, ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);
REGISTER_OP_CPU_KERNEL(nearest_interp, ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>,
                       ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(nearest_interp_grad, ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);
REGISTER_OP_CPU_KERNEL(trilinear_interp, ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>,
                       ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(trilinear_interp_grad, ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);
REGISTER_OPERATOR(linear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
                  ops::InterpolateGradMaker<paddle::framework::OpDesc>,
                  ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(linear_interp_grad, ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(linear_interp, ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>,
                       ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(linear_interp_grad, ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);
REGISTER_OP_CPU_KERNEL(bicubic_interp, ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>);
REGISTER_OP_CPU_KERNEL(bicubic_interp_grad, ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);