/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/fluid/operators/interpolate_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using framework::Tensor;
using DataLayout = framework::DataLayout;

static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) {
  auto dim_x = ctx->GetInputDim("X");
  auto interp_method = ctx->Attrs().Get<std::string>("interp_method");

  PADDLE_ENFORCE(
      "bilinear" == interp_method || "nearest" == interp_method,
      "Interpolation method can only be \"bilinear\" or \"nearest\" when "
      "Input(X) dimension is 4");
  const DataLayout data_layout = framework::StringToDataLayout(
      ctx->Attrs().Get<std::string>("data_layout"));

  if (ctx->HasInputs("SizeTensor")) {
    // top priority size
    auto inputs_name = ctx->Inputs("SizeTensor");
    PADDLE_ENFORCE_EQ(
        inputs_name.size(), 2,
        "Input(SizeTensor)'size of Op(interpolate) must be 2. "
        "Attr(out_shape)'s length must be 2 for 4-D input tensor.");
    int out_h = ctx->Attrs().Get<int>("out_h");
    int out_w = ctx->Attrs().Get<int>("out_w");
    framework::DDim dim_out;
    if (data_layout == DataLayout::kNCHW) {
      dim_out = {dim_x[0], dim_x[1], out_h, out_w};
    } else {
      dim_out = {dim_x[0], out_h, out_w, dim_x[3]};
    }
    ctx->SetOutputDim("Out", dim_out);

    return;
  }

  int out_h, out_w;
  if (ctx->HasInput("Scale")) {
    auto scale_tensor = ctx->GetInputDim("Scale");
    PADDLE_ENFORCE_EQ(scale_tensor.size(), 1,
                      "Scale's dimension size must be 1.");
    out_h = -1;
    out_w = -1;
  } else {
    float scale = ctx->Attrs().Get<float>("scale");
    if (scale > 0) {
      // round down
      out_h = (data_layout == DataLayout::kNCHW
                   ? static_cast<int>(dim_x[2] * scale)
                   : static_cast<int>(dim_x[1] * scale));
      out_w = (data_layout == DataLayout::kNCHW
                   ? static_cast<int>(dim_x[3] * scale)
                   : static_cast<int>(dim_x[2] * scale));
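      // Illustrative example: an NCHW input of shape [2, 3, 6, 4] with
      // scale = 1.5 yields out_h = static_cast<int>(6 * 1.5) = 9 and
      // out_w = static_cast<int>(4 * 1.5) = 6; the cast truncates toward
      // zero, which implements the "round down" noted above.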
      // protect when input shape is -1
      out_h = out_h > 0 ? out_h : -1;
      out_w = out_w > 0 ? out_w : -1;
    } else {
      out_h = ctx->Attrs().Get<int>("out_h");
      out_w = ctx->Attrs().Get<int>("out_w");
    }
  }

  if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
    auto out_size_dim = ctx->GetInputDim("OutSize");
    PADDLE_ENFORCE_EQ(out_size_dim.size(), 1,
                      "OutSize's dimension size must be 1");
    PADDLE_ENFORCE_EQ(out_size_dim[0], 2, "OutSize's dim[0] must be 2");
    ctx->ShareLoD("X", "Out");
    return;
  }

  framework::DDim dim_out;
  if (data_layout == DataLayout::kNCHW) {
    dim_out = {dim_x[0], dim_x[1], out_h, out_w};
  } else {
    dim_out = {dim_x[0], out_h, out_w, dim_x[3]};
  }
  ctx->SetOutputDim("Out", dim_out);
}

static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) {
  auto dim_x = ctx->GetInputDim("X");
  auto interp_method = ctx->Attrs().Get<std::string>("interp_method");

  PADDLE_ENFORCE("trilinear" == interp_method,
                 "Interpolation method can only be \"trilinear\" when Input(X) "
                 "dimension is 5");
  const DataLayout data_layout = framework::StringToDataLayout(
      ctx->Attrs().Get<std::string>("data_layout"));

  if (ctx->HasInputs("SizeTensor")) {
    // top priority size
    auto inputs_name = ctx->Inputs("SizeTensor");
    PADDLE_ENFORCE_EQ(
        inputs_name.size(), 3,
        "Input(SizeTensor)'s size of Op(interpolate) must be 3. "
        "Attr(out_shape)'s length must be 3 for 5-D input tensor.");
    int out_d = ctx->Attrs().Get<int>("out_d");
    int out_h = ctx->Attrs().Get<int>("out_h");
    int out_w = ctx->Attrs().Get<int>("out_w");
    framework::DDim dim_out;
    if (data_layout == DataLayout::kNCHW) {
      dim_out = {dim_x[0], dim_x[1], out_d, out_h, out_w};
    } else {
      dim_out = {dim_x[0], out_d, out_h, out_w, dim_x[4]};
    }
    ctx->SetOutputDim("Out", dim_out);

    return;
  }

  int out_d, out_h, out_w;
  if (ctx->HasInput("Scale")) {
    auto scale_tensor = ctx->GetInputDim("Scale");
    PADDLE_ENFORCE_EQ(scale_tensor.size(), 1,
                      "Scale's dimension size must be 1");
    out_d = -1;
    out_h = -1;
    out_w = -1;
  } else {
    float scale = ctx->Attrs().Get<float>("scale");
    if (scale > 0) {
      // round down
      out_d = (data_layout == DataLayout::kNCHW
                   ? static_cast<int>(dim_x[2] * scale)
                   : static_cast<int>(dim_x[1] * scale));
      out_h = (data_layout == DataLayout::kNCHW
                   ? static_cast<int>(dim_x[3] * scale)
                   : static_cast<int>(dim_x[2] * scale));
      out_w = (data_layout == DataLayout::kNCHW
                   ? static_cast<int>(dim_x[4] * scale)
                   : static_cast<int>(dim_x[3] * scale));
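      // As in the 2-D case, the casts truncate toward zero; e.g. an NCDHW
      // input of shape [2, 3, 4, 6, 4] with scale = 1.5 yields
      // out_d = 6, out_h = 9 and out_w = 6.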
      // protect when input shape is -1
      out_d = out_d > 0 ? out_d : -1;
      out_h = out_h > 0 ? out_h : -1;
      out_w = out_w > 0 ? out_w : -1;
    } else {
      out_d = ctx->Attrs().Get<int>("out_d");
      out_h = ctx->Attrs().Get<int>("out_h");
      out_w = ctx->Attrs().Get<int>("out_w");
    }
  }

  if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
    auto out_size_dim = ctx->GetInputDim("OutSize");
    PADDLE_ENFORCE_EQ(out_size_dim.size(), 1,
                      "OutSize's dimension size must be 1");
    PADDLE_ENFORCE_EQ(out_size_dim[0], 3, "OutSize's dim[0] must be 3");
    ctx->ShareLoD("X", "Out");
    return;
  }

  framework::DDim dim_out;
  if (data_layout == DataLayout::kNCHW) {
    dim_out = {dim_x[0], dim_x[1], out_d, out_h, out_w};
  } else {
    dim_out = {dim_x[0], out_d, out_h, out_w, dim_x[4]};
  }
  ctx->SetOutputDim("Out", dim_out);
}

class InterpolateOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of InterpolateOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of InterpolationOp should not be null.");

    auto dim_x = ctx->GetInputDim("X");  // NCHW or NCDHW format
    PADDLE_ENFORCE(dim_x.size() == 4 || dim_x.size() == 5,
                   "Input(X) dimension must be 4 or 5");

    if (dim_x.size() == 4) {
      // shape check for 2D interpolate for input tensor shape NCHW
      Interpolate2DInferShapeCheck(ctx);
    } else {  // dim_x.size() == 5
      // shape check for 3D interpolate for input tensor shape NCDHW
      Interpolate3DInferShapeCheck(ctx);
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
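    // SizeTensor and Scale carry shape/scale values rather than image data,
    // so they are passed through without any layout or place transform.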
    if (var_name == "SizeTensor" || var_name == "Scale") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};

class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "The input tensor of interpolate operator, "
             "This is a 4-D tensor with shape of [N, C, H, W] or a "
             "5-D tensor with shape of [N, C, D, H, W].");
    AddInput("OutSize",
             "This is a 1-D tensor with two numbers to specify output size. "
             "It should be [output_height, output_width] when input is a 4-D "
             "tensor and should be [output_depth, output_height, output_width] "
             "when input is a 5-D tensor. It has a higher priority than "
             "the attr(out_d), attr(out_h), attr(out_w) and attr(scale).")
        .AsDispensable();
    AddInput("SizeTensor",
             "(vector<Tensor<int32>>, optional). If provided, interpolate will "
             "use this. The shape of the tensor in vector MUST BE [1]. "
             "It has the highest priority compare with Input(OutSize) and "
             "attr(out_d), attr(out_h), attr(out_w) and attr(scale).")
        .AsDuplicable()
        .AsDispensable();
    AddInput("Scale",
             "This is a 1-D tensor with one number to specify output scale. "
             "It has the higher priority compare with attr(scale).")
        .AsDispensable();
    AddOutput("Out",
              "The output tensor of interpolate operator, "
              "This is a tensor in same rank with Input(X).");

    AddAttr<std::string>(
        "data_layout",
        "(string, default NCHW), an optional string from: \"NHWC\", \"NCHW\". "
        "Specify whether the data format of the input and output data is "
        "channel_first or channel_last.")
        .SetDefault("NCHW");
    AddAttr<int>("out_d", "output depth of interpolate op.").SetDefault(0);
    AddAttr<int>("out_h", "output height of interpolate op.").SetDefault(0);
    AddAttr<int>("out_w", "output width of interpolate op.").SetDefault(0);
    AddAttr<float>("scale", "scale factor of interpolate op.").SetDefault(0.);
    AddAttr<std::string>("interp_method",
                         "(string, default \"bilinear\"), interpolation "
                         "method, can be \"bilinear\" for "
                         "bilinear interpolation, \"trilinear\" for trilinear "
                         "interpolation and \"nearest\" for nearest "
                         "neighbor interpolation.")
        .SetDefault("bilinear");
    AddAttr<bool>(
        "align_corners",
        "an optional bool. Defaults to True. "
        "If True, the centers of 4 corner pixels of the input and output "
        "tensors are aligned, preserving the values at the corner pixels, "
T
Tink_Y 已提交
274
        "If False, are not aligned")
        .SetDefault(true);
    AddAttr<int>("align_mode",
                 "(int, default \'1\'), optional for bilinear interpolation, "
                 "can be \'0\' for src_idx = scale*(dst_indx+0.5)-0.5 , "
                 "can be \'1\' for src_idx = scale*dst_index .")
        .SetDefault(1);
    AddComment(R"DOC(
          This operator samples input X to the given output shape using the
          specified interpolation method; the interpolation methods can be
          \"nearest\" for nearest neighbor interpolation, \"bilinear\" for
          bilinear interpolation and \"trilinear\" for trilinear interpolation.

          Nearest neighbor interpolation performs nearest neighbor interpolation
          in both the 3rd dimension (the height direction) and the 4th dimension
          (the width direction) on the input tensor.
            
          Bilinear interpolation is an extension of linear interpolation for 
          interpolating functions of two variables (e.g. H-direction and 
          W-direction in this op) on a rectilinear 2D grid. The key idea is 
          to perform linear interpolation first in one direction, and then 
          again in the other direction.

          Trilinear interpolation is an extension of linear interpolation for 
          interpolating functions of three variables (e.g. D-direction, 
          H-direction and W-direction in this op) on a rectilinear 3D grid. 
          The linear interpolation is performed in three directions.

          Align_corners and align_mode are optional parameters; the calculation method 
          of interpolation can be selected by them.
          
          Example:

          For scale:
          
            if align_corners = True and out_{size} > 1:

              scale_{factor} = (in_{size}-1.0)/(out_{size}-1.0)
            
            else:
              
              scale_{factor} = float(in_{size}/out_{size})
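
            A short worked example (values chosen for illustration): with
            in_{size} = 4 and out_{size} = 8, align_corners = True gives
            scale_{factor} = (4-1.0)/(8-1.0) = 3.0/7.0, while
            align_corners = False gives scale_{factor} = 4/8 = 0.5.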
            
          
          Nearest neighbor interpolation:
          
          if:
              align_corners = False

              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:

              H_out = \left \lfloor {H_{in} * scale_{factor}} \right \rfloor
              W_out = \left \lfloor {W_{in} * scale_{factor}} \right \rfloor

          else:
              align_corners = True

              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:

              H_out = round(H_{in} * scale_{factor})
              W_out = round(W_{in} * scale_{factor})
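
              For example (illustrative): H_in = 5 with scale_{factor} = 0.6
              gives H_out = round(5 * 0.6) = round(3.0) = 3.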

          Bilinear interpolation:

          if:
              align_corners = False , align_mode = 0
              
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              
              H_out = (H_{in}+0.5) * scale_{factor} - 0.5
              W_out = (W_{in}+0.5) * scale_{factor} - 0.5


          else:
           
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:

              H_out = H_{in} * scale_{factor}
              W_out = W_{in} * scale_{factor}

          Trilinear interpolation:

          if:
              align_corners = False , align_mode = 0
              
              input : (N,C,D_in,H_in,W_in)
              output: (N,C,D_out,H_out,W_out) where:
              
              D_out = (D_{in}+0.5) * scale_{factor} - 0.5
              H_out = (H_{in}+0.5) * scale_{factor} - 0.5
              W_out = (W_{in}+0.5) * scale_{factor} - 0.5


          else:
           
              input : (N,C,D_in,H_in,W_in)
              output: (N,C,D_out,H_out,W_out) where:

              D_out = D_{in} * scale_{factor}
              H_out = H_{in} * scale_{factor}
              W_out = W_{in} * scale_{factor}
          

          For details of nearest neighbor interpolation, please refer to Wikipedia: 
          https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation

          For details of bilinear interpolation, please refer to Wikipedia: 
          https://en.wikipedia.org/wiki/Bilinear_interpolation

          For details of trilinear interpolation, please refer to Wikipedia: 
          https://en.wikipedia.org/wiki/Trilinear_interpolation
         )DOC");
  }
};

class InterpolateOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) should not be null");
    auto dim_x = ctx->GetInputDim("X");
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.GetPlace());
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "SizeTensor" || var_name == "Scale") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};

class InterpolateGradDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
    op->SetType(ForwardOp().Type() + "_grad");
    op->SetInput("X", Input("X"));
    if (ForwardOp().Inputs().count("SizeTensor") > 0) {
      op->SetInput("SizeTensor", Input("SizeTensor"));
    }
    if (ForwardOp().Inputs().count("OutSize") > 0) {
      op->SetInput("OutSize", Input("OutSize"));
    }
    if (ForwardOp().Inputs().count("Scale") > 0) {
      op->SetInput("Scale", Input("Scale"));
    }
    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    op->SetAttrMap(Attrs());
    return op;
  }
};

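// The gradient kernel only needs the shape of X, not its contents, so X is
// declared a no-need-buffer variable and its memory can be released early.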
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(InterpolateGradNoNeedBufferVarsInference,
                                      "X");

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(bilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
                  ops::InterpolateGradDescMaker);
REGISTER_OPERATOR(bilinear_interp_grad, ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInference);
REGISTER_OPERATOR(nearest_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
                  ops::InterpolateGradDescMaker);
REGISTER_OPERATOR(nearest_interp_grad, ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInference);
REGISTER_OPERATOR(trilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
                  ops::InterpolateGradDescMaker);
REGISTER_OPERATOR(trilinear_interp_grad, ops::InterpolateOpGrad,
                  ops::InterpolateGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>,
                       ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(bilinear_interp_grad, ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);
REGISTER_OP_CPU_KERNEL(nearest_interp, ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>,
                       ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(nearest_interp_grad, ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);
REGISTER_OP_CPU_KERNEL(trilinear_interp, ops::InterpolateKernel<float>,
                       ops::InterpolateKernel<double>,
                       ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(trilinear_interp_grad, ops::InterpolateGradKernel<float>,
                       ops::InterpolateGradKernel<double>);