未验证 提交 89d7d866 编写于 作者: X xiaoting 提交者: GitHub

add intepolte_v2 (#26520)

* add intepolte_v2

* fix linear interp

* polish unittest, test=develop

* update code samples to 2.0 API, test=develop

* remove warning, test_develop

* add name in attrs, test=develop

* polish code, test=develop

* change Align to align, test=develop

* fix unittest in py3,test=develop

* fix coverage, test=develop

* fix coverage, test=develop

* fix for windows ci, test=develop

* fix coverage, test=develop
上级 fc5acdd0
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/interpolate_v2_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using framework::Tensor;
using DataLayout = framework::DataLayout;
static void Interpolate1DInferShapeCheck(framework::InferShapeContext* ctx) {
auto dim_x = ctx->GetInputDim("X");
auto interp_method = ctx->Attrs().Get<std::string>("interp_method");
PADDLE_ENFORCE_EQ("linear", interp_method,
platform::errors::InvalidArgument(
"Interpolation method can only be \"linear\" when"
"Input(X) dimension is 3, but got method = %s .",
interp_method));
const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout"));
if (ctx->HasInputs("SizeTensor")) {
// top prority size
auto inputs_name = ctx->Inputs("SizeTensor");
PADDLE_ENFORCE_EQ(
inputs_name.size(), 1,
platform::errors::InvalidArgument(
"Input(SizeTensor)'size of Op(interpolate) must be 1. "
"Attr(out_shape)'s length must be 1 for 3-D input tensor, but got "
"size = %d .",
inputs_name.size()));
int out_w = ctx->Attrs().Get<int>("out_w");
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {dim_x[0], dim_x[1], out_w};
} else {
dim_out = {dim_x[0], out_w, dim_x[2]};
}
ctx->SetOutputDim("Out", dim_out);
return;
}
int out_w;
if (ctx->HasInput("Scale")) {
auto scale_tensor = ctx->GetInputDim("Scale");
PADDLE_ENFORCE_EQ(
scale_tensor.size(), 1,
platform::errors::InvalidArgument(
"Scale's dimension size must be 1, but got dimension = %d .",
scale_tensor.size()));
PADDLE_ENFORCE_EQ(
scale_tensor[0], 1,
platform::errors::InvalidArgument(
"Scale's shape must be 1, but got shape = %d .", scale_tensor[0]));
// out_w = -1;
} else {
auto scale = ctx->Attrs().Get<std::vector<float>>("scale");
if (scale.size() > 0) {
float scale_w = -1;
scale_w = scale[0];
PADDLE_ENFORCE_EQ(scale_w > 0, true, platform::errors::InvalidArgument(
"scale of Op(interpolate) "
"should be greater than 0."));
if (scale_w > 0.) {
// round down
out_w = (data_layout == DataLayout::kNCHW
? static_cast<int>(dim_x[2] * scale_w)
: static_cast<int>(dim_x[1] * scale_w));
// protect when input shape is -1
out_w = out_w > 0 ? out_w : -1;
}
} else {
out_w = ctx->Attrs().Get<int>("out_w");
}
}
if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
auto out_size_dim = ctx->GetInputDim("OutSize");
PADDLE_ENFORCE_EQ(
out_size_dim.size(), 1,
platform::errors::InvalidArgument(
"OutSize's dimension size must be 1, but got dimention = %d .",
out_size_dim.size()));
PADDLE_ENFORCE_EQ(out_size_dim[0], 1, platform::errors::InvalidArgument(
"OutSize's dim[0] must be 1"));
ctx->ShareLoD("X", "Out");
return;
}
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {dim_x[0], dim_x[1], out_w};
} else {
dim_out = {dim_x[0], out_w, dim_x[2]};
}
ctx->SetOutputDim("Out", dim_out);
}
static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) {
auto dim_x = ctx->GetInputDim("X");
auto interp_method = ctx->Attrs().Get<std::string>("interp_method");
PADDLE_ENFORCE(
"bilinear" == interp_method || "nearest" == interp_method ||
"bicubic" == interp_method,
"Interpolation method can only be \"bilinear\" or \"nearest\" when "
"Input(X) dimension is 4, but got method = %s .",
interp_method);
const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout"));
if (ctx->HasInputs("SizeTensor")) {
// top prority size
auto inputs_name = ctx->Inputs("SizeTensor");
PADDLE_ENFORCE_EQ(
inputs_name.size(), 2,
platform::errors::InvalidArgument(
"Input(SizeTensor)'size of Op(interpolate) must be 2. "
"Attr(out_shape)'s length must be 2 for 4-D input "
"tensor, but got size = %d .",
inputs_name.size()));
int out_h = ctx->Attrs().Get<int>("out_h");
int out_w = ctx->Attrs().Get<int>("out_w");
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {dim_x[0], dim_x[1], out_h, out_w};
} else {
dim_out = {dim_x[0], out_h, out_w, dim_x[3]};
}
ctx->SetOutputDim("Out", dim_out);
return;
}
int out_h, out_w;
if (ctx->HasInput("Scale")) {
auto scale_tensor = ctx->GetInputDim("Scale");
PADDLE_ENFORCE_EQ(
scale_tensor.size(), 1,
platform::errors::InvalidArgument(
"Scale's dimension size must be 1, but got dimension = %d .",
scale_tensor.size()));
PADDLE_ENFORCE_EQ(scale_tensor[0] == 2 || scale_tensor[0] == 1, true,
platform::errors::InvalidArgument(
"Scale's shape must be 2 or 1, but got shape = %d .",
scale_tensor[0]));
// out_h = -1;
// out_w = -1;
} else {
auto scale = ctx->Attrs().Get<std::vector<float>>("scale");
if (scale.size() > 0) {
float scale_h = -1;
float scale_w = -1;
scale_h = scale[0];
scale_w = scale[1];
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
if (scale_h > 0. && scale_w > 0.) {
// round down
out_h = (data_layout == DataLayout::kNCHW
? static_cast<int>(dim_x[2] * scale_h)
: static_cast<int>(dim_x[1] * scale_h));
out_w = (data_layout == DataLayout::kNCHW
? static_cast<int>(dim_x[3] * scale_w)
: static_cast<int>(dim_x[2] * scale_w));
// protect when input shape is -1
out_h = out_h > 0 ? out_h : -1;
out_w = out_w > 0 ? out_w : -1;
}
} else {
out_h = ctx->Attrs().Get<int>("out_h");
out_w = ctx->Attrs().Get<int>("out_w");
}
}
if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
auto out_size_dim = ctx->GetInputDim("OutSize");
PADDLE_ENFORCE_EQ(
out_size_dim.size(), 1,
platform::errors::InvalidArgument(
"OutSize's dimension size must be 1, but got dimension = %d .",
out_size_dim.size()));
PADDLE_ENFORCE_EQ(
out_size_dim[0], 2,
platform::errors::InvalidArgument(
"OutSize's dim[0] must be 2, but got dimention = %d .",
out_size_dim[0]));
ctx->ShareLoD("X", "Out");
return;
}
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {dim_x[0], dim_x[1], out_h, out_w};
} else {
dim_out = {dim_x[0], out_h, out_w, dim_x[3]};
}
ctx->SetOutputDim("Out", dim_out);
}
static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) {
auto dim_x = ctx->GetInputDim("X");
auto interp_method = ctx->Attrs().Get<std::string>("interp_method");
PADDLE_ENFORCE_EQ(
"trilinear", interp_method,
platform::errors::InvalidArgument(
"Interpolation method can only be \"trilinear\" when Input(X) "
"dimension is 5, but got method = %s .",
interp_method));
const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout"));
if (ctx->HasInputs("SizeTensor")) {
// top prority size
auto inputs_name = ctx->Inputs("SizeTensor");
PADDLE_ENFORCE_EQ(
inputs_name.size(), 3,
platform::errors::InvalidArgument(
"Input(SizeTensor)'s size of Op(interpolate) must be 3. "
"Attr(out_shape)'s length must be 3 for 5-D input "
"tensor, but got size = %d .",
inputs_name.size()));
int out_d = ctx->Attrs().Get<int>("out_d");
int out_h = ctx->Attrs().Get<int>("out_h");
int out_w = ctx->Attrs().Get<int>("out_w");
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {dim_x[0], dim_x[1], out_d, out_h, out_w};
} else {
dim_out = {dim_x[0], out_d, out_h, out_w, dim_x[4]};
}
ctx->SetOutputDim("Out", dim_out);
return;
}
int out_d, out_h, out_w;
if (ctx->HasInput("Scale")) {
auto scale_tensor = ctx->GetInputDim("Scale");
PADDLE_ENFORCE_EQ(
scale_tensor.size(), 1,
platform::errors::InvalidArgument(
"Scale's dimension size must be 1, but got size = %d .",
scale_tensor.size()));
PADDLE_ENFORCE_EQ(scale_tensor[0] == 3 || scale_tensor[0] == 1, true,
platform::errors::InvalidArgument(
"Scale's shape must be 3 or 1, but got shape = %d .",
scale_tensor[0]));
// out_d = -1;
// out_h = -1;
// out_w = -1;
} else {
auto scale = ctx->Attrs().Get<std::vector<float>>("scale");
if (scale.size() > 0) {
float scale_d = -1;
float scale_h = -1;
float scale_w = -1;
scale_d = scale[0];
scale_h = scale[1];
scale_w = scale[2];
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0 && scale_d > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
if (scale_d > 0. && scale_h > 0. && scale_w > 0.) {
// round down
out_d = (data_layout == DataLayout::kNCHW
? static_cast<int>(dim_x[2] * scale_d)
: static_cast<int>(dim_x[1] * scale_d));
out_h = (data_layout == DataLayout::kNCHW
? static_cast<int>(dim_x[3] * scale_h)
: static_cast<int>(dim_x[2] * scale_h));
out_w = (data_layout == DataLayout::kNCHW
? static_cast<int>(dim_x[4] * scale_w)
: static_cast<int>(dim_x[3] * scale_w));
// protect when input shape is -1
out_d = out_d > 0 ? out_d : -1;
out_h = out_h > 0 ? out_h : -1;
out_w = out_w > 0 ? out_w : -1;
}
} else {
out_d = ctx->Attrs().Get<int>("out_d");
out_h = ctx->Attrs().Get<int>("out_h");
out_w = ctx->Attrs().Get<int>("out_w");
}
}
if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
auto out_size_dim = ctx->GetInputDim("OutSize");
PADDLE_ENFORCE_EQ(out_size_dim.size(), 1,
"OutSize's dimension size must be 1, but got size =%d .",
out_size_dim.size());
PADDLE_ENFORCE_EQ(out_size_dim[0], 3,
"OutSize's dim[0] must be 3, but got size = %d .",
out_size_dim[0]);
ctx->ShareLoD("X", "Out");
return;
}
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {dim_x[0], dim_x[1], out_d, out_h, out_w};
} else {
dim_out = {dim_x[0], out_d, out_h, out_w, dim_x[4]};
}
ctx->SetOutputDim("Out", dim_out);
}
class InterpolateV2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of InterpolateV2Op should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of InterpolationOp should not be null.");
auto dim_x = ctx->GetInputDim("X"); // NCHW format
PADDLE_ENFORCE(
dim_x.size() == 3 || dim_x.size() == 4 || dim_x.size() == 5,
platform::errors::Unimplemented(
"Input(X) dimension must be 3, 4 or 5, but got dimension = %d .",
dim_x.size()));
if (dim_x.size() == 3) {
// shape check for 1D interpolate for input tensor shape NCHW
Interpolate1DInferShapeCheck(ctx);
} else if (dim_x.size() == 4) {
// shape check for 2D interpolate for input tensor shape NCHW
Interpolate2DInferShapeCheck(ctx);
} else { // dim_x.size() == 5
// shape check for 3D interpolate for input tensor shape NCDHW
Interpolate3DInferShapeCheck(ctx);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "SizeTensor" || var_name == "Scale") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
class InterpolateV2OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"The input tensor of interpolate operator, "
"This is a 4-D tensor with shape of [N, C, H, W] or a "
"5-D tensor with shape of [N, C, D, H, W].");
AddInput("OutSize",
"This is a 1-D tensor with two numbers to specify output size. "
"It should be [output_height, output_width] when input is a 4-D "
"tensor and should be [output_depth, output_height, output_width] "
"when input is a 5-D tensor. It has a higher priority than "
"the attr(out_d), attr(out_h), attr(out_w) and attr(scale).")
.AsDispensable();
AddInput("SizeTensor",
"(vector<Tensor<int32>>, optional). If provided, interpolate will "
"use this. The shape of the tensor in vector MUST BE [1]. "
"It has the highest priority compare with Input(OutSize) and "
"attr(out_d), attr(out_h), attr(out_w) and attr(scale).")
.AsDuplicable()
.AsDispensable();
AddInput("Scale",
"This is a 1-D tensor with one number to specify output scale. "
"It has the higher priority compare with attr(scale).")
.AsDispensable();
AddOutput("Out",
"The output tensor of interpolate operator, "
"This is a tensor in same rank with Input(X).");
AddAttr<std::string>(
"data_layout",
"(string, default NCHW) Only used in "
"an optional string from: \"NHWC\", \"NCHW\". "
"Specify that the data format of the input and output data is "
"channel_first or channel_last.")
.SetDefault("NCHW");
AddAttr<int>("out_d", "output depth of interpolate op.").SetDefault(0);
AddAttr<int>("out_h", "output height of interpolate op.").SetDefault(0);
AddAttr<int>("out_w", "output width of interpolate op.").SetDefault(0);
AddAttr<std::vector<float>>("scale", "scale_d factor of interpolate op.")
.SetDefault(std::vector<float>{});
AddAttr<std::string>("interp_method",
"(string, default \"bilinear\"), interpolation "
"method, can be \"linear\" for linear interpolation"
",\"bilinear\" for "
"bilinear interpolation, \"trilinear\" for trilinear "
"interpolation and \"nearest\" for nearest "
"neighbor interpolation, and \"bicubic\" for bicubic"
"interpolation.")
.SetDefault("bilinear");
AddAttr<bool>(
"align_corners",
"an optional bool. Defaults to True. "
"If True, the centers of 4 corner pixels of the input and output "
"tensors are aligned, preserving the values at the corner pixels, "
"If False, are not aligned")
.SetDefault(true);
AddAttr<int>("align_mode",
"(int, default \'1\'), optional for bilinear interpolation, "
"can be \'0\' for src_idx = scale*(dst_indx+0.5)-0.5 , "
"can be \'1\' for src_idx = scale*dst_index .")
.SetDefault(1);
AddComment(R"DOC(
This operator samples input X to given output shape by using specified
interpolation method, the interpolation methods can be \"nearest\"
for nearest neighbor interpolation and \"bilinear\" for bilinear
interpolation and \"linear\" for linear interpolation..
Nearest neighbor interpolation is to perform nearest neighbor interpolation
in both the 3rd dimension(in height direction) and the 4th dimension(in width
direction) on input tensor.
Linear interpolation is the method of using a line connecting two known quantities
to determine the value of an unknown quantity between the two known quantities.
Bilinear interpolation is an extension of linear interpolation for
interpolating functions of two variables (e.g. H-direction and
W-direction in this op) on a rectilinear 2D grid. The key idea is
to perform linear interpolation first in one direction, and then
again in the other direction.
Trilinear interpolation is an extension of linear interpolation for
interpolating functions of three variables (e.g. D-direction,
H-direction and W-direction in this op) on a rectilinear 3D grid.
The linear interpolation is performed on three directions.
Bicubic interpolation is an extension of cubic interpolation for interpolating
data points on a two-dimensional regular grid. The interpolated surface is
smoother than corresponding surfaces obtained by bilinear interpolation or
nearest-neighbor interpolation.
Align_corners and align_mode are optional parameters,the calculation method
of interpolation can be selected by them.
Example:
For scale:
if align_corners = True and out_{size}>1 :
scale_{factor} = (in_{size}-1.0)/(out_{size}-1.0)
else:
scale_{factor} = float(in_{size}/out_{size})
Nearest neighbor interpolation:
if:
align_corners = False
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = \left \lfloor {H_{in} * scale_{}factor}} \right \rfloor
W_out = \left \lfloor {W_{in} * scale_{}factor}} \right \rfloor
else:
align_corners = True
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = round(H_{in} * scale_{factor})
W_out = round(W_{in} * scale_{factor})
Bilinear interpolation:
if:
align_corners = False , align_mode = 0
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = (H_{in}+0.5) * scale_{factor} - 0.5
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
Trilinear interpolation:
if:
align_corners = False , align_mode = 0
input : (N,C,D_in,H_in,W_in)
output: (N,C,D_out,H_out,W_out) where:
D_out = (D_{in}+0.5) * scale_{factor} - 0.5
H_out = (H_{in}+0.5) * scale_{factor} - 0.5
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,D_in,H_in,W_in)
output: (N,C,D_out,H_out,W_out) where:
D_out = D_{in} * scale_{factor}
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
Bicubic interpolation:
if:
align_corners = False
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = (H_{in}+0.5) * scale_{factor} - 0.5
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
For details of nearest neighbor interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation
For details of bilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bilinear_interp_v2olation
For details of trilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Trilinear_interp_v2olation
For details of bicubic interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bicubic_interpolation
)DOC");
}
};
class InterpolateV2OpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto dim_x = ctx->GetInputDim("X");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "SizeTensor" || var_name == "Scale") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
template <typename T>
class InterpolateV2GradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("X", this->Input("X"));
if (this->HasInput("SizeTensor") > 0) {
op->SetInput("SizeTensor", this->Input("SizeTensor"));
}
if (this->HasInput("OutSize") > 0) {
op->SetInput("OutSize", this->Input("OutSize"));
}
if (this->HasInput("Scale") > 0) {
op->SetInput("Scale", this->Input("Scale"));
}
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(InterpolateV2GradNoNeedBufferVarsInferer,
"X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(bilinear_interp_v2, ops::InterpolateV2Op,
ops::InterpolateV2OpMaker,
ops::InterpolateV2GradMaker<paddle::framework::OpDesc>,
ops::InterpolateV2GradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(bilinear_interp_v2_grad, ops::InterpolateV2OpGrad,
ops::InterpolateV2GradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(nearest_interp_v2, ops::InterpolateV2Op,
ops::InterpolateV2OpMaker,
ops::InterpolateV2GradMaker<paddle::framework::OpDesc>,
ops::InterpolateV2GradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(nearest_interp_v2_grad, ops::InterpolateV2OpGrad,
ops::InterpolateV2GradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(trilinear_interp_v2, ops::InterpolateV2Op,
ops::InterpolateV2OpMaker,
ops::InterpolateV2GradMaker<paddle::framework::OpDesc>,
ops::InterpolateV2GradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(trilinear_interp_v2_grad, ops::InterpolateV2OpGrad,
ops::InterpolateV2GradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(bicubic_interp_v2, ops::InterpolateV2Op,
ops::InterpolateV2OpMaker,
ops::InterpolateV2GradMaker<paddle::framework::OpDesc>,
ops::InterpolateV2GradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(bicubic_interp_v2_grad, ops::InterpolateV2OpGrad,
ops::InterpolateV2GradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(bilinear_interp_v2, ops::InterpolateV2Kernel<float>,
ops::InterpolateV2Kernel<double>,
ops::InterpolateV2Kernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(bilinear_interp_v2_grad,
ops::InterpolateV2GradKernel<float>,
ops::InterpolateV2GradKernel<double>);
REGISTER_OP_CPU_KERNEL(nearest_interp_v2, ops::InterpolateV2Kernel<float>,
ops::InterpolateV2Kernel<double>,
ops::InterpolateV2Kernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(nearest_interp_v2_grad,
ops::InterpolateV2GradKernel<float>,
ops::InterpolateV2GradKernel<double>);
REGISTER_OP_CPU_KERNEL(trilinear_interp_v2, ops::InterpolateV2Kernel<float>,
ops::InterpolateV2Kernel<double>,
ops::InterpolateV2Kernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(trilinear_interp_v2_grad,
ops::InterpolateV2GradKernel<float>,
ops::InterpolateV2GradKernel<double>);
REGISTER_OPERATOR(linear_interp_v2, ops::InterpolateV2Op,
ops::InterpolateV2OpMaker,
ops::InterpolateV2GradMaker<paddle::framework::OpDesc>,
ops::InterpolateV2GradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(linear_interp_v2_grad, ops::InterpolateV2OpGrad,
ops::InterpolateV2GradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(linear_interp_v2, ops::InterpolateV2Kernel<float>,
ops::InterpolateV2Kernel<double>,
ops::InterpolateV2Kernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(linear_interp_v2_grad,
ops::InterpolateV2GradKernel<float>,
ops::InterpolateV2GradKernel<double>);
REGISTER_OP_CPU_KERNEL(bicubic_interp_v2, ops::InterpolateV2Kernel<float>,
ops::InterpolateV2Kernel<double>);
REGISTER_OP_CPU_KERNEL(bicubic_interp_v2_grad,
ops::InterpolateV2GradKernel<float>,
ops::InterpolateV2GradKernel<double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <string>
#include "paddle/fluid/operators/interpolate_v2_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
namespace paddle {
namespace operators {
using framework::Tensor;
using DataLayout = framework::DataLayout;
template <typename T>
__global__ void KeNearestNeighborInterpFw(
const T* in, const size_t in_img_h, const size_t in_img_w,
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const float ratio_h, const float ratio_w,
const bool align_corners, const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idy = (out_id_w % out_img_size) / out_img_w;
out_img_idx = tid % out_img_w;
} else {
out_img_idy = out_id_w / (out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int in_img_idy = (align_corners)
? static_cast<int>(ratio_h * out_img_idy + 0.5)
: static_cast<int>(ratio_h * out_img_idy);
int in_img_idx = (align_corners)
? static_cast<int>(ratio_w * out_img_idx + 0.5)
: static_cast<int>(ratio_w * out_img_idx);
if (data_layout == DataLayout::kNCHW) {
out[tid] = in[out_id_h * input_w + channel_id * in_img_size +
in_img_idy * in_img_w + in_img_idx];
} else {
out[tid] = in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id];
}
}
}
template <typename T>
__global__ void KeNearestNeighborInterpBw(
T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
const size_t input_w, const T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const float ratio_h, const float ratio_w,
const bool align_corners, const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idy = (out_id_w % out_img_size) / out_img_w;
out_img_idx = tid % out_img_w;
} else {
out_img_idy = out_id_w / (out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int in_img_idy = (align_corners)
? static_cast<int>(ratio_h * out_img_idy + 0.5)
: static_cast<int>(ratio_h * out_img_idy);
int in_img_idx = (align_corners)
? static_cast<int>(ratio_w * out_img_idx + 0.5)
: static_cast<int>(ratio_w * out_img_idx);
T* in_pos;
if (data_layout == DataLayout::kNCHW) {
in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
in_img_idy * in_img_w + in_img_idx];
} else {
in_pos = &in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id];
}
const T out_pos = out[out_id_h * output_w + out_id_w];
platform::CudaAtomicAdd(in_pos, out_pos);
}
}
template <typename T>
__global__ void KeLinearInterpFw(const T* in, const size_t in_img_w,
const size_t input_w, T* out,
const size_t out_img_w, const size_t output_h,
const size_t output_w,
const size_t num_channels, const float ratio_w,
const bool align_corners, const int align_mode,
const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
bool align_flag = (align_mode == 0 && !align_corners);
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idx = tid % out_img_w;
} else {
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int in_img_idx = align_flag
? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
: static_cast<int>(ratio_w * out_img_idx);
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; // w
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; // w_id
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
src_w = (src_w > 0) ? src_w : 0;
T w1lambda =
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
if (data_layout == DataLayout::kNCHW) {
const T* in_pos =
&in[out_id_h * out_id_w + channel_id * in_img_size + in_img_idx];
// linear interpolation
out[out_id_h * output_w + out_id_w] =
w2lambda * in_pos[0] + w1lambda * in_pos[w_id];
} else {
const T* in_pos =
&in[out_id_h * input_w + in_img_idx * num_channels + channel_id];
// linear interpolation
out[out_id_h * output_w + out_id_w] =
w2lambda * in_pos[0] + w1lambda * in_pos[w_id * num_channels];
}
}
}
template <typename T>
__global__ void KeLinearInterpBw(T* in, const size_t in_img_w,
const size_t input_w, const T* out,
const size_t out_img_w, const size_t output_h,
const size_t output_w,
const size_t num_channels, const T ratio_w,
const bool align_corners, const int align_mode,
const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
bool align_flag = (align_mode == 0 && !align_corners);
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idx = tid % out_img_w;
} else {
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int in_img_idx = align_flag ? ratio_w * (out_img_idx + 0.5) - 0.5
: ratio_w * out_img_idx;
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; // w
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; // w_id
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
src_w = (src_w > 0) ? src_w : 0;
T w1lambda =
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
T* in_pos;
if (data_layout == DataLayout::kNCHW) {
in_pos = &in[out_id_h * input_w + channel_id * in_img_size + in_img_idx];
} else {
in_pos = &in[out_id_h * input_w + in_img_idx * num_channels + channel_id];
}
const T* out_pos = &out[out_id_w];
if (data_layout == DataLayout::kNCHW) {
platform::CudaAtomicAdd(&in_pos[0], w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos[w_id], w1lambda * out_pos[0]);
} else {
platform::CudaAtomicAdd(&in_pos[0], w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos[w_id * num_channels],
w1lambda * out_pos[0]);
}
}
}
template <typename T>
__global__ void KeBilinearInterpFw(
const T* in, const size_t in_img_h, const size_t in_img_w,
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const float ratio_h, const float ratio_w,
const bool align_corners, const int align_mode,
const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
bool align_flag = (align_mode == 0 && !align_corners);
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idy = (out_id_w % out_img_size) / out_img_w;
out_img_idx = tid % out_img_w;
} else {
out_img_idy = out_id_w / (out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int in_img_idy = align_flag
? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5)
: static_cast<int>(ratio_h * out_img_idy);
in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
T src_h = ratio_h * (out_img_idy + 0.5) - 0.5;
src_h = (src_h > 0) ? src_h : 0;
T h1lambda =
align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy;
T h2lambda = 1.f - h1lambda;
int in_img_idx = align_flag
? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
: static_cast<int>(ratio_w * out_img_idx);
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
src_w = (src_w > 0) ? src_w : 0;
T w1lambda =
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
if (data_layout == DataLayout::kNCHW) {
const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
in_img_idy * in_img_w + in_img_idx];
// bilinear interpolation
out[out_id_h * output_w + out_id_w] =
h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) +
h1lambda * (w2lambda * in_pos[h_id * in_img_w] +
w1lambda * in_pos[h_id * in_img_w + w_id]);
} else {
const T* in_pos =
&in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id];
// bilinear interpolation
out[out_id_h * output_w + out_id_w] =
h2lambda *
(w2lambda * in_pos[0] + w1lambda * in_pos[w_id * num_channels]) +
h1lambda * (w2lambda * in_pos[h_id * in_img_w * num_channels] +
w1lambda * in_pos[h_id * in_img_w * num_channels +
w_id * num_channels]);
}
}
}
template <typename T>
__global__ void KeBilinearInterpBw(
T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
const size_t input_w, const T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const T ratio_h, const T ratio_w,
const bool align_corners, const int align_mode,
const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
bool align_flag = (align_mode == 0 && !align_corners);
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idy = (out_id_w % out_img_size) / out_img_w;
out_img_idx = tid % out_img_w;
} else {
out_img_idy = out_id_w / (out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int in_img_idy = align_flag ? ratio_h * (out_img_idy + 0.5) - 0.5
: ratio_h * out_img_idy;
in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
T src_h = ratio_h * (out_img_idy + 0.5) - 0.5;
src_h = (src_h > 0) ? src_h : 0;
T h1lambda =
align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy;
T h2lambda = 1.f - h1lambda;
int in_img_idx = align_flag ? ratio_w * (out_img_idx + 0.5) - 0.5
: ratio_w * out_img_idx;
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
src_w = (src_w > 0) ? src_w : 0;
T w1lambda =
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
T* in_pos;
if (data_layout == DataLayout::kNCHW) {
in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
in_img_idy * in_img_w + in_img_idx];
} else {
in_pos = &in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id];
}
const T* out_pos = &out[out_id_h * output_w + out_id_w];
if (data_layout == DataLayout::kNCHW) {
platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos[w_id], h2lambda * w1lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos[h_id * in_img_w],
h1lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos[h_id * in_img_w + w_id],
h1lambda * w1lambda * out_pos[0]);
} else {
platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos[w_id * num_channels],
h2lambda * w1lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos[h_id * in_img_w * num_channels],
h1lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(
&in_pos[h_id * in_img_w * num_channels + w_id * num_channels],
h1lambda * w1lambda * out_pos[0]);
}
}
}
template <typename T>
__global__ void KeTrilinearInterpFw(
const T* in, const size_t in_img_d, const size_t in_img_h,
const size_t in_img_w, const size_t input_h, const size_t input_w, T* out,
const size_t out_img_d, const size_t out_img_h, const size_t out_img_w,
const size_t output_h, const size_t output_w, const size_t num_channels,
const float ratio_d, const float ratio_h, const float ratio_w,
const bool align_corners, const int align_mode,
const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
bool align_flag = (align_mode == 0 && !align_corners);
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idt, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idt = (out_id_w % out_img_size) / out_img_h / out_img_w;
out_img_idy = ((out_id_w % out_img_size) / out_img_w) % out_img_h;
out_img_idx = tid % out_img_w;
} else {
out_img_idt = out_id_w / (out_img_h * out_img_w * num_channels);
out_img_idy = out_id_w % (out_img_h * out_img_w * num_channels) /
(out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int in_img_idt = align_flag
? static_cast<int>(ratio_d * (out_img_idt + 0.5) - 0.5)
: static_cast<int>(ratio_d * out_img_idt);
in_img_idt = (in_img_idt > 0) ? in_img_idt : 0;
int d_id = (in_img_idt < in_img_d - 1) ? 1 : 0;
T src_d = ratio_d * (out_img_idt + 0.5) - 0.5;
src_d = (src_d > 0) ? src_d : 0;
T d1lambda =
align_flag ? src_d - in_img_idt : ratio_d * out_img_idt - in_img_idt;
T d2lambda = 1.f - d1lambda;
int in_img_idy = align_flag
? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5)
: static_cast<int>(ratio_h * out_img_idy);
in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
T src_h = ratio_h * (out_img_idy + 0.5) - 0.5;
src_h = (src_h > 0) ? src_h : 0;
T h1lambda =
align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy;
T h2lambda = 1.f - h1lambda;
int in_img_idx = align_flag
? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
: static_cast<int>(ratio_w * out_img_idx);
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
src_w = (src_w > 0) ? src_w : 0;
T w1lambda =
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
if (data_layout == DataLayout::kNCHW) {
int in_pos1_idx = out_id_h * input_w + channel_id * in_img_size +
(in_img_idt * in_img_h + in_img_idy) * in_img_w +
in_img_idx;
const T* in_pos1 = &in[in_pos1_idx];
int in_pos2_idx = in_pos1_idx + d_id * in_img_h * in_img_w;
const T* in_pos2 = &in[in_pos2_idx];
// trilinear interpolation
out[out_id_h * output_w + out_id_w] =
d2lambda *
(h2lambda * (w2lambda * in_pos1[0] + w1lambda * in_pos1[w_id]) +
h1lambda * (w2lambda * in_pos1[h_id * in_img_w] +
w1lambda * in_pos1[h_id * in_img_w + w_id])) +
d1lambda *
(h2lambda * (w2lambda * in_pos2[0] + w1lambda * in_pos2[w_id]) +
h1lambda * (w2lambda * in_pos2[h_id * in_img_w] +
w1lambda * in_pos2[h_id * in_img_w + w_id]));
} else {
int in_pos1_idx = out_id_h * input_w +
in_img_idt * in_img_h * in_img_w * num_channels +
in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id;
const T* in_pos1 = &in[in_pos1_idx];
int in_pos2_idx = in_pos1_idx + d_id * in_img_h * in_img_w * num_channels;
const T* in_pos2 = &in[in_pos2_idx];
// trilinear interpolation
out[out_id_h * output_w + out_id_w] =
d2lambda *
(h2lambda * (w2lambda * in_pos1[0] +
w1lambda * in_pos1[w_id * num_channels]) +
h1lambda * (w2lambda * in_pos1[h_id * in_img_w * num_channels] +
w1lambda * in_pos1[h_id * in_img_w * num_channels +
w_id * num_channels])) +
d1lambda *
(h2lambda * (w2lambda * in_pos2[0] +
w1lambda * in_pos2[w_id * num_channels]) +
h1lambda * (w2lambda * in_pos2[h_id * in_img_w * num_channels] +
w1lambda * in_pos2[h_id * in_img_w * num_channels +
w_id * num_channels]));
}
}
}
template <typename T>
__global__ void KeTrilinearInterpBw(
T* in, const size_t in_img_d, const size_t in_img_h, const size_t in_img_w,
const size_t input_h, const size_t input_w, const T* out,
const size_t out_img_d, const size_t out_img_h, const size_t out_img_w,
const size_t output_h, const size_t output_w, const size_t num_channels,
const T ratio_d, const T ratio_h, const T ratio_w, const bool align_corners,
const int align_mode, const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
bool align_flag = (align_mode == 0 && !align_corners);
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idt, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idt = (out_id_w % out_img_size) / out_img_h / out_img_w;
out_img_idy = ((out_id_w % out_img_size) / out_img_w) % out_img_h;
out_img_idx = tid % out_img_w;
} else {
out_img_idt = out_id_w / (out_img_h * out_img_w * num_channels);
out_img_idy = out_id_w % (out_img_h * out_img_w * num_channels) /
(out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int in_img_idt = align_flag
? static_cast<int>(ratio_d * (out_img_idt + 0.5) - 0.5)
: static_cast<int>(ratio_d * out_img_idt);
in_img_idt = (in_img_idt > 0) ? in_img_idt : 0;
int d_id = (in_img_idt < in_img_d - 1) ? 1 : 0;
T src_d = ratio_d * (out_img_idt + 0.5) - 0.5;
src_d = (src_d > 0) ? src_d : 0;
T d1lambda =
align_flag ? src_d - in_img_idt : ratio_d * out_img_idt - in_img_idt;
T d2lambda = 1.f - d1lambda;
int in_img_idy = align_flag
? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5)
: static_cast<int>(ratio_h * out_img_idy);
in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
T src_h = ratio_h * (out_img_idy + 0.5) - 0.5;
src_h = (src_h > 0) ? src_h : 0;
T h1lambda =
align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy;
T h2lambda = 1.f - h1lambda;
int in_img_idx = align_flag
? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
: static_cast<int>(ratio_w * out_img_idx);
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
src_w = (src_w > 0) ? src_w : 0;
T w1lambda =
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
if (data_layout == DataLayout::kNCHW) {
int in_pos1_idx = out_id_h * input_w + channel_id * in_img_size +
(in_img_idt * in_img_h + in_img_idy) * in_img_w +
in_img_idx;
T* in_pos1 = &in[in_pos1_idx];
int in_pos2_idx = in_pos1_idx + d_id * in_img_h * in_img_w;
T* in_pos2 = &in[in_pos2_idx];
const T* out_pos = &out[out_id_h * output_w + out_id_w];
// trilinear interpolation grad
platform::CudaAtomicAdd(&in_pos1[0],
d2lambda * h2lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos1[w_id],
d2lambda * h2lambda * w1lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos1[h_id * in_img_w],
d2lambda * h1lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos1[h_id * in_img_w + w_id],
d2lambda * h1lambda * w1lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos2[0],
d1lambda * h2lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos2[w_id],
d1lambda * h2lambda * w1lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos2[h_id * in_img_w],
d1lambda * h1lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos2[h_id * in_img_w + w_id],
d1lambda * h1lambda * w1lambda * out_pos[0]);
} else {
int in_pos1_idx = out_id_h * input_w +
in_img_idt * in_img_h * in_img_w * num_channels +
in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id;
T* in_pos1 = &in[in_pos1_idx];
int in_pos2_idx = in_pos1_idx + d_id * in_img_h * in_img_w * num_channels;
T* in_pos2 = &in[in_pos2_idx];
const T* out_pos = &out[out_id_h * output_w + out_id_w];
// trilinear interpolation grad
platform::CudaAtomicAdd(&in_pos1[0],
d2lambda * h2lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos1[w_id * num_channels],
d2lambda * h2lambda * w1lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos1[h_id * in_img_w * num_channels],
d2lambda * h1lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(
&in_pos1[h_id * in_img_w * num_channels + w_id * num_channels],
d2lambda * h1lambda * w1lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos2[0],
d1lambda * h2lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos2[w_id * num_channels],
d1lambda * h2lambda * w1lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos2[h_id * in_img_w * num_channels],
d1lambda * h1lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(
&in_pos2[h_id * in_img_w * num_channels + w_id * num_channels],
d1lambda * h1lambda * w1lambda * out_pos[0]);
}
}
}
template <typename T>
__device__ __forceinline__ static T Kecubic_interp(const T x0, const T x1,
const T x2, const T x3,
T t) {
T coeffs[4];
T a = -0.75;
T x_1 = t;
T x_2 = 1.0 - t;
coeffs[0] = cubic_convolution2<T>(x_1 + 1.0, a);
coeffs[1] = cubic_convolution1<T>(x_1, a);
coeffs[2] = cubic_convolution1<T>(x_2, a);
coeffs[3] = cubic_convolution2<T>(x_2 + 1.0, a);
return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
}
template <typename T>
__global__ void KeBicubicInterpFw(
const T* in, const size_t in_img_h, const size_t in_img_w,
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const float ratio_h, const float ratio_w,
const bool align_corners, const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idy = (out_id_w % out_img_size) / out_img_w;
out_img_idx = tid % out_img_w;
} else {
out_img_idy = out_id_w / (out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
T in_img_idy = align_corners
? static_cast<T>(ratio_h * out_img_idy)
: static_cast<T>(ratio_h * (out_img_idy + 0.5) - 0.5);
int input_y = floorf(in_img_idy);
const T y_t = in_img_idy - input_y;
T in_img_idx = align_corners
? static_cast<T>(ratio_w * out_img_idx)
: static_cast<T>(ratio_w * (out_img_idx + 0.5) - 0.5);
int input_x = floorf(in_img_idx);
const T x_t = in_img_idx - input_x;
T coefficients[4];
const T* in_pos_0;
const T* in_pos_1;
const T* in_pos_2;
const T* in_pos_3;
int access_x_0;
if (data_layout == DataLayout::kNCHW) {
for (int k = 0; k < 4; k++) {
int access_y =
max(min(input_y - 1 + k, static_cast<int>(in_img_h - 1)), 0);
access_x_0 = max(min(input_x - 1, static_cast<int>(in_img_w - 1)), 0);
int access_x_1 =
max(min(input_x + 0, static_cast<int>(in_img_w - 1)), 0);
int access_x_2 =
max(min(input_x + 1, static_cast<int>(in_img_w - 1)), 0);
int access_x_3 =
max(min(input_x + 2, static_cast<int>(in_img_w - 1)), 0);
in_pos_0 = &in[out_id_h * input_w + channel_id * in_img_size +
access_y * in_img_w + access_x_0];
in_pos_1 = &in[out_id_h * input_w + channel_id * in_img_size +
access_y * in_img_w + access_x_1];
in_pos_2 = &in[out_id_h * input_w + channel_id * in_img_size +
access_y * in_img_w + access_x_2];
in_pos_3 = &in[out_id_h * input_w + channel_id * in_img_size +
access_y * in_img_w + access_x_3];
coefficients[k] = Kecubic_interp<T>(in_pos_0[0], in_pos_1[0],
in_pos_2[0], in_pos_3[0], x_t);
}
out[out_id_h * output_w + out_id_w] =
Kecubic_interp<T>(coefficients[0], coefficients[1], coefficients[2],
coefficients[3], y_t);
} else {
for (int k = 0; k < 4; k++) {
int access_y =
max(min(input_y - 1 + k, static_cast<int>((in_img_h - 1))), 0);
int access_x_0 =
max(min(input_x - 1, static_cast<int>((in_img_w - 1))), 0);
int access_x_1 =
max(min(input_x + 0, static_cast<int>((in_img_w - 1))), 0);
int access_x_2 =
max(min(input_x + 1, static_cast<int>((in_img_w - 1))), 0);
int access_x_3 =
max(min(input_x + 2, static_cast<int>((in_img_w - 1))), 0);
const T* in_pos_0 =
&in[out_id_h * input_w + access_y * in_img_w * num_channels +
access_x_0 * num_channels + channel_id];
const T* in_pos_1 =
&in[out_id_h * input_w + access_y * in_img_w * num_channels +
access_x_1 * num_channels + channel_id];
const T* in_pos_2 =
&in[out_id_h * input_w + access_y * in_img_w * num_channels +
access_x_2 * num_channels + channel_id];
const T* in_pos_3 =
&in[out_id_h * input_w + access_y * in_img_w * num_channels +
access_x_3 * num_channels + channel_id];
coefficients[k] = Kecubic_interp(in_pos_0[0], in_pos_1[0], in_pos_2[0],
in_pos_3[0], x_t);
}
out[out_id_h * output_w + out_id_w] =
static_cast<T>(Kecubic_interp(coefficients[0], coefficients[1],
coefficients[2], coefficients[3], y_t));
}
}
}
template <typename T>
__global__ void KeBicubicInterpBw(
T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
const size_t input_w, const T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const float ratio_h, const float ratio_w,
const bool align_corners, const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idy = (out_id_w % out_img_size) / out_img_w;
out_img_idx = tid % out_img_w;
} else {
out_img_idy = out_id_w / (out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
T in_img_idy = align_corners
? static_cast<T>(ratio_h * out_img_idy)
: static_cast<T>(ratio_h * (out_img_idy + 0.5) - 0.5);
int input_y = floorf(in_img_idy);
const T y_t = in_img_idy - input_y;
T in_img_idx = align_corners
? static_cast<T>(ratio_w * out_img_idx)
: static_cast<T>(ratio_w * (out_img_idx + 0.5) - 0.5);
int input_x = floorf(in_img_idx);
const T x_t = in_img_idx - input_x;
T x_coeffs[4];
T y_coeffs[4];
get_cubic_upsample_coefficients(x_coeffs, x_t);
get_cubic_upsample_coefficients(y_coeffs, y_t);
const T* out_pos = &out[out_id_h * output_w + out_id_w];
T* in_pos;
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
int access_y = max(min(static_cast<int>(input_y - 1 + j),
static_cast<int>(in_img_h - 1)),
0);
int access_x = max(min(static_cast<int>(input_x - 1 + i),
static_cast<int>(in_img_w - 1)),
0);
if (data_layout == DataLayout::kNCHW) {
in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
access_y * in_img_w + access_x];
} else {
in_pos = &in[out_id_h * input_w + access_y * in_img_w * num_channels +
access_x * num_channels + channel_id];
}
platform::CudaAtomicAdd(&in_pos[0],
(out_pos[0] * y_coeffs[j] * x_coeffs[i]));
}
}
}
}
template <typename T>
static void Interpolate1DCUDAFwd(const framework::ExecutionContext& ctx,
const Tensor& input, Tensor* output) {
auto* input_data = input.data<T>();
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_w = ctx.Attr<int>("out_w");
auto list_new_shape_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_shape_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_shape_tensor);
out_w = new_size[0];
} else {
float scale_w = -1;
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
scale_w = scale_data[0];
PADDLE_ENFORCE_EQ(scale_w > 0, true, platform::errors::InvalidArgument(
"scale of Op(interpolate) "
"should be greater than 0."));
} else {
if (scale.size() > 0) {
scale_w = scale[0];
PADDLE_ENFORCE_EQ(scale_w > 0, true, platform::errors::InvalidArgument(
"scale of Op(interpolate) "
"should be greater than 0."));
}
}
if (scale_w > 0.) {
out_w = static_cast<int>(in_w * scale_w);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopySync(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_w = size_data[0];
}
}
PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument(
"out_w in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {n, c, out_w};
} else {
dim_out = {n, out_w, c};
}
auto output_data = output->mutable_data<T>(dim_out, ctx.GetPlace());
if (in_w == out_w) {
framework::TensorCopy(input, ctx.GetPlace(), output);
return;
}
float ratio_w = 0.f;
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1.0) / (out_w - 1.0)
: static_cast<float>(in_w) / out_w;
}
int in_cw = c * in_w;
int out_cw = c * out_w;
int pixelNum = n * out_cw;
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(pixelNum, ctx);
if ("linear" == interp_method) {
KeLinearInterpFw<T><<<config.blocks, config.threads, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_w, in_cw, output_data, out_w, n, out_cw, c, ratio_w,
align_corners, align_mode, data_layout);
}
}
template <typename T>
static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
const Tensor& input, Tensor* output) {
auto* input_data = input.data<T>();
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
auto list_new_shape_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_shape_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_shape_tensor);
out_h = new_size[0];
out_w = new_size[1];
} else {
float scale_h = -1;
float scale_w = -1;
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
if (scale_data.size() > 1) {
scale_h = scale_data[0];
scale_w = scale_data[1];
} else {
scale_h = scale_data[0];
scale_w = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
} else {
if (scale.size() > 1) {
scale_w = scale[1];
scale_h = scale[0];
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
}
}
if (scale_w > 0. && scale_h > 0.) {
out_h = static_cast<int>(in_h * scale_h);
out_w = static_cast<int>(in_w * scale_w);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopySync(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_h = size_data[0];
out_w = size_data[1];
}
}
PADDLE_ENFORCE_GT(out_h, 0, platform::errors::InvalidArgument(
"out_h in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument(
"out_w in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {n, c, out_h, out_w};
} else {
dim_out = {n, out_h, out_w, c};
}
auto output_data = output->mutable_data<T>(dim_out, ctx.GetPlace());
if (in_h == out_h && in_w == out_w) {
framework::TensorCopy(input, ctx.GetPlace(), output);
return;
}
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
}
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
int in_hw = in_h * in_w;
int out_hw = out_h * out_w;
int in_chw = c * in_hw;
int out_chw = c * out_hw;
int pixelNum = n * out_chw;
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(pixelNum, ctx);
if ("nearest" == interp_method) {
KeNearestNeighborInterpFw<T><<<config.blocks, config.threads, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
} else if ("bilinear" == interp_method) {
KeBilinearInterpFw<T><<<config.blocks, config.threads, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners, align_mode, data_layout);
} else if ("bicubic" == interp_method) {
KeBicubicInterpFw<
T><<<config.blocks, 512, 0, ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
}
}
template <typename T>
static void Interpolate3DCUDAFwd(const framework::ExecutionContext& ctx,
const Tensor& input, Tensor* output) {
auto* input_data = input.data<T>();
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_d = ctx.Attr<int>("out_d");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
auto list_new_shape_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_shape_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_shape_tensor);
out_d = new_size[0];
out_h = new_size[1];
out_w = new_size[2];
} else {
float scale_d = -1;
float scale_h = -1;
float scale_w = -1;
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
if (scale_data.size() > 1) {
scale_d = scale_data[0];
scale_h = scale_data[1];
scale_w = scale_data[2];
} else {
scale_d = scale_data[0];
scale_h = scale_data[0];
scale_w = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0 && scale_d > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
} else {
if (scale.size() > 1) {
scale_d = scale[0];
scale_h = scale[1];
scale_w = scale[2];
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0 && scale_d > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
}
}
if (scale_d > 0. && scale_h > 0. && scale_w > 0.) {
out_d = static_cast<int>(in_d * scale_d);
out_h = static_cast<int>(in_h * scale_h);
out_w = static_cast<int>(in_w * scale_w);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopySync(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_d = size_data[0];
out_h = size_data[1];
out_w = size_data[2];
}
}
PADDLE_ENFORCE_GT(out_d, 0, platform::errors::InvalidArgument(
"out_d in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
PADDLE_ENFORCE_GT(out_h, 0, platform::errors::InvalidArgument(
"out_h in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument(
"out_w in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {n, c, out_d, out_h, out_w};
} else {
dim_out = {n, out_d, out_h, out_w, c};
}
auto output_data = output->mutable_data<T>(dim_out, ctx.GetPlace());
if (in_d == out_d && in_h == out_h && in_w == out_w) {
framework::TensorCopy(input, ctx.GetPlace(), output);
return;
}
float ratio_d = 0.f;
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_d > 1) {
ratio_d = (align_corners) ? static_cast<float>(in_d - 1) / (out_d - 1)
: static_cast<float>(in_d) / out_d;
}
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
}
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
int in_dhw = in_d * in_h * in_w;
int out_dhw = out_d * out_h * out_w;
int in_cdhw = c * in_dhw;
int out_cdhw = c * out_dhw;
int pixelNum = n * out_cdhw;
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(pixelNum, ctx);
if ("trilinear" == interp_method) {
KeTrilinearInterpFw<T><<<config.blocks, config.threads, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_d, in_h, in_w, n, in_cdhw, output_data, out_d, out_h,
out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
align_mode, data_layout);
}
}
template <typename T>
static void Interpolate1DCUDABwd(const framework::ExecutionContext& ctx,
Tensor* input_grad, const Tensor output_grad) {
auto* input = ctx.Input<Tensor>("X");
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_w = ctx.Attr<int>("out_w");
float scale_w = -1;
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
scale_w = scale_data[0];
PADDLE_ENFORCE_EQ(scale_w > 0, true, platform::errors::InvalidArgument(
"scale of Op(interpolate) "
"should be greater than 0."));
} else {
if (scale.size() > 0) {
scale_w = scale[0];
PADDLE_ENFORCE_EQ(scale_w > 0, true, platform::errors::InvalidArgument(
"scale of Op(interpolate) "
"should be greater than 0."));
}
}
if (scale_w > 0.) {
out_w = static_cast<int>(in_w * scale_w);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopySync(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_w = size_data[0];
}
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_size_tensor);
out_w = new_size[0];
}
auto* output_grad_data = output_grad.data<T>();
framework::DDim dim_grad;
if (data_layout == DataLayout::kNCHW) {
dim_grad = {n, c, in_w};
} else {
dim_grad = {n, in_w, c};
}
input_grad->mutable_data<T>(dim_grad, ctx.GetPlace());
auto* input_grad_data = input_grad->mutable_data<T>(dim_grad, ctx.GetPlace());
auto& device_ctx = ctx.template device_context<platform::CUDADeviceContext>();
math::SetConstant<platform::CUDADeviceContext, T> zero;
zero(device_ctx, input_grad, static_cast<T>(0.0));
if (in_w == out_w) {
framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
return;
}
float ratio_w = 0.f;
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
int in_cw = c * in_w;
int out_cw = c * out_w;
int pixelNum = n * out_cw;
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(pixelNum, ctx);
if ("linear" == interp_method) {
KeLinearInterpBw<T><<<config.blocks, config.threads, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_w, in_cw, output_grad_data, out_w, n, out_cw, c,
ratio_w, align_corners, align_mode, data_layout);
}
}
template <typename T>
static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
Tensor* input_grad, const Tensor output_grad) {
auto* input = ctx.Input<Tensor>("X");
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale_h = -1;
float scale_w = -1;
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
if (scale_data.size() > 1) {
scale_h = scale_data[0];
scale_w = scale_data[1];
} else {
scale_h = scale_data[0];
scale_w = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
} else {
if (scale.size() > 1) {
scale_w = scale[1];
scale_h = scale[0];
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
}
}
if (scale_w > 0. && scale_h > 0.) {
out_h = static_cast<int>(in_h * scale_h);
out_w = static_cast<int>(in_w * scale_w);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopySync(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_h = size_data[0];
out_w = size_data[1];
}
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_size_tensor);
out_h = new_size[0];
out_w = new_size[1];
}
auto* output_grad_data = output_grad.data<T>();
framework::DDim dim_grad;
if (data_layout == DataLayout::kNCHW) {
dim_grad = {n, c, in_h, in_w};
} else {
dim_grad = {n, in_h, in_w, c};
}
input_grad->mutable_data<T>(dim_grad, ctx.GetPlace());
auto* input_grad_data = input_grad->mutable_data<T>(dim_grad, ctx.GetPlace());
auto& device_ctx = ctx.template device_context<platform::CUDADeviceContext>();
math::SetConstant<platform::CUDADeviceContext, T> zero;
zero(device_ctx, input_grad, static_cast<T>(0.0));
if (in_h == out_h && in_w == out_w) {
framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
return;
}
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
}
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
int in_hw = in_h * in_w;
int out_hw = out_h * out_w;
int in_chw = c * in_hw;
int out_chw = c * out_hw;
int pixelNum = n * out_chw;
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(pixelNum, ctx);
if ("nearest" == interp_method) {
KeNearestNeighborInterpBw<T><<<config.blocks, config.threads, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
} else if ("bilinear" == interp_method) {
KeBilinearInterpBw<T><<<config.blocks, config.threads, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode,
data_layout);
} else if ("bicubic" == interp_method) {
KeBicubicInterpBw<
T><<<config.blocks, 512, 0, ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
}
}
template <typename T>
static void Interpolate3DCUDABwd(const framework::ExecutionContext& ctx,
Tensor* input_grad,
const Tensor& output_grad) {
auto* input = ctx.Input<Tensor>("X");
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_d = ctx.Attr<int>("out_d");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale_d = -1;
float scale_h = -1;
float scale_w = -1;
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
if (scale_data.size() > 1) {
scale_d = scale_data[0];
scale_h = scale_data[1];
scale_w = scale_data[2];
} else {
scale_d = scale_data[0];
scale_h = scale_data[0];
scale_w = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0 && scale_d > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
} else {
if (scale.size() > 1) {
scale_d = scale[0];
scale_h = scale[1];
scale_w = scale[2];
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0 && scale_d > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
}
}
if (scale_d > 0. && scale_h > 0. && scale_w > 0.) {
out_d = static_cast<int>(in_d * scale_d);
out_h = static_cast<int>(in_h * scale_h);
out_w = static_cast<int>(in_w * scale_w);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopySync(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_d = size_data[0];
out_h = size_data[1];
out_w = size_data[2];
}
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_size_tensor);
out_d = new_size[0];
out_h = new_size[1];
out_w = new_size[2];
}
auto* output_grad_data = output_grad.data<T>();
framework::DDim dim_grad;
if (data_layout == DataLayout::kNCHW) {
dim_grad = {n, c, in_d, in_h, in_w};
} else {
dim_grad = {n, in_d, in_h, in_w, c};
}
auto* input_grad_data = input_grad->mutable_data<T>(dim_grad, ctx.GetPlace());
auto& device_ctx = ctx.template device_context<platform::CUDADeviceContext>();
math::SetConstant<platform::CUDADeviceContext, T> zero;
zero(device_ctx, input_grad, static_cast<T>(0.0));
if (in_d == out_d && in_h == out_h && in_w == out_w) {
framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
return;
}
float ratio_d = 0.f;
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_d > 1) {
ratio_d = (align_corners) ? static_cast<float>(in_d - 1) / (out_d - 1)
: static_cast<float>(in_d) / out_d;
}
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
}
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
int in_dhw = in_d * in_h * in_w;
int out_dhw = out_d * out_h * out_w;
int in_cdhw = c * in_dhw;
int out_cdhw = c * out_dhw;
int pixelNum = n * out_cdhw;
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(pixelNum, ctx);
if ("trilinear" == interp_method) {
KeTrilinearInterpBw<T><<<config.blocks, config.threads, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_d, in_h, in_w, n, in_cdhw, output_grad_data, out_d,
out_h, out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
align_mode, data_layout);
}
}
template <typename T>
class InterpolateOpV2CUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::NotFound("This kernel only runs on GPU device."));
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
auto input_dims = input->dims();
if (input_dims.size() == 3) { // 1D interpolation
Interpolate1DCUDAFwd<T>(ctx, *input, output);
} else if (input_dims.size() == 4) { // 2D interpolation
Interpolate2DCUDAFwd<T>(ctx, *input, output);
} else if (input_dims.size() == 5) { // 3D interpolation
Interpolate3DCUDAFwd<T>(ctx, *input, output);
}
}
};
template <typename T>
class InterpolateV2GradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::NotFound("This kernel only runs on GPU device."));
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto output_grad_dims = output_grad->dims();
if (output_grad_dims.size() == 3) { // 1D interpolation
Interpolate1DCUDABwd<T>(ctx, input_grad, *output_grad);
} else if (output_grad_dims.size() == 4) { // 2D interpolation
Interpolate2DCUDABwd<T>(ctx, input_grad, *output_grad);
} else if (output_grad_dims.size() == 5) { // 3D interpolation
Interpolate3DCUDABwd<T>(ctx, input_grad, *output_grad);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(bilinear_interp_v2,
ops::InterpolateOpV2CUDAKernel<float>,
ops::InterpolateOpV2CUDAKernel<double>,
ops::InterpolateOpV2CUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(bilinear_interp_v2_grad,
ops::InterpolateV2GradOpCUDAKernel<float>,
ops::InterpolateV2GradOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(nearest_interp_v2,
ops::InterpolateOpV2CUDAKernel<float>,
ops::InterpolateOpV2CUDAKernel<double>,
ops::InterpolateOpV2CUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(nearest_interp_v2_grad,
ops::InterpolateV2GradOpCUDAKernel<float>,
ops::InterpolateV2GradOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(trilinear_interp_v2,
ops::InterpolateOpV2CUDAKernel<float>,
ops::InterpolateOpV2CUDAKernel<double>,
ops::InterpolateOpV2CUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(trilinear_interp_v2_grad,
ops::InterpolateV2GradOpCUDAKernel<float>,
ops::InterpolateV2GradOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(linear_interp_v2, ops::InterpolateOpV2CUDAKernel<float>,
ops::InterpolateOpV2CUDAKernel<double>,
ops::InterpolateOpV2CUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(linear_interp_v2_grad,
ops::InterpolateV2GradOpCUDAKernel<float>,
ops::InterpolateV2GradOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(bicubic_interp_v2,
ops::InterpolateOpV2CUDAKernel<float>,
ops::InterpolateOpV2CUDAKernel<double>,
ops::InterpolateOpV2CUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(bicubic_interp_v2_grad,
ops::InterpolateV2GradOpCUDAKernel<float>,
ops::InterpolateV2GradOpCUDAKernel<double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace operators {
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;
inline std::vector<int> get_new_shape(
const std::vector<const Tensor*>& list_new_shape_tensor) {
// get tensor from
std::vector<int> vec_new_shape;
for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) {
auto tensor = list_new_shape_tensor[i];
PADDLE_ENFORCE_EQ(
tensor->dims(), framework::make_ddim({1}),
platform::errors::InvalidArgument("shape of dim tensor should be [1]"));
if (platform::is_gpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
vec_new_shape.push_back(static_cast<int32_t>(*temp.data<int32_t>()));
} else {
vec_new_shape.push_back(static_cast<int32_t>(*tensor->data<int32_t>()));
}
}
return vec_new_shape;
}
template <typename T>
inline std::vector<T> get_new_data_from_tensor(const Tensor* new_data_tensor) {
std::vector<T> vec_new_data;
auto* new_data = new_data_tensor->data<T>();
framework::Tensor cpu_starts_tensor;
if (platform::is_gpu_place(new_data_tensor->place())) {
TensorCopySync(*new_data_tensor, platform::CPUPlace(), &cpu_starts_tensor);
new_data = cpu_starts_tensor.data<T>();
}
vec_new_data = std::vector<T>(new_data, new_data + new_data_tensor->numel());
return vec_new_data;
}
inline void ExtractNCDWH(const framework::DDim& dims,
const DataLayout& data_layout, int* N, int* C, int* D,
int* H, int* W) {
*N = dims[0];
if (dims.size() == 3) {
*C = data_layout == DataLayout::kNCHW ? dims[1] : dims[2];
*D = 1;
*H = 1;
*W = data_layout == DataLayout::kNCHW ? dims[2] : dims[1];
} else if (dims.size() == 4) {
*C = data_layout == DataLayout::kNCHW ? dims[1] : dims[3];
*D = 1;
*H = data_layout == DataLayout::kNCHW ? dims[2] : dims[1];
*W = data_layout == DataLayout::kNCHW ? dims[3] : dims[2];
} else {
*C = data_layout == DataLayout::kNCHW ? dims[1] : dims[4];
*D = data_layout == DataLayout::kNCHW ? dims[2] : dims[1];
*H = data_layout == DataLayout::kNCHW ? dims[3] : dims[2];
*W = data_layout == DataLayout::kNCHW ? dims[4] : dims[3];
}
}
template <typename T>
static void NearestNeighborInterpolate(const Tensor& input, Tensor* output,
const float ratio_h, const float ratio_w,
const int n, const int c,
const int out_h, const int out_w,
const bool align_corners,
const DataLayout& data_layout) {
auto input_t = EigenTensor<T, 4>::From(input);
auto output_t = EigenTensor<T, 4>::From(*output);
for (int k = 0; k < out_h; k++) { // loop for images
int in_k = (align_corners) ? static_cast<int>(ratio_h * k + 0.5)
: static_cast<int>(ratio_h * k);
for (int l = 0; l < out_w; l++) {
int in_l = (align_corners) ? static_cast<int>(ratio_w * l + 0.5)
: static_cast<int>(ratio_w * l);
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
if (data_layout == DataLayout::kNCHW) {
output_t(i, j, k, l) = input_t(i, j, in_k, in_l);
} else {
output_t(i, k, l, j) = input_t(i, in_k, in_l, j);
}
}
}
}
}
}
template <typename T>
static void LinearInterpolation(const Tensor& input, Tensor* output,
const float ratio_w, const int in_w,
const int n, const int c, const int out_w,
const bool align_corners, const bool align_mode,
const DataLayout data_layout) {
auto input_t = EigenTensor<T, 3>::From(input);
auto output_t = EigenTensor<T, 3>::From(*output);
bool align_flag = (align_mode == 0 && !align_corners);
std::vector<int> vx_w, vx_e;
std::vector<float> vd_w, vd_e;
vx_w.reserve(out_w);
vx_e.reserve(out_w);
vd_w.reserve(out_w);
vd_e.reserve(out_w);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int l = 0; l < out_w; l++) {
int x_w = align_flag ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
: static_cast<int>(ratio_w * l);
x_w = (x_w > 0) ? x_w : 0; // w
int x_e = (x_w < (in_w - 1)) ? (x_w + 1) : x_w; // w_id
float idx_src_x = ratio_w * (l + 0.5) - 0.5;
idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w; // w1lambda
float d_e = 1.f - d_w; // w2lambda
{
vx_w[l] = x_w;
vx_e[l] = x_e;
vd_w[l] = d_w;
vd_e[l] = d_e;
}
}
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(3)
#endif
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
for (int l = 0; l < out_w; l++) {
// linear interpolation
T out_t;
if (data_layout == DataLayout::kNCHW) {
out_t = input_t(i, j, vx_w[l]) * vd_e[l] +
input_t(i, j, vx_e[l]) * vd_w[l];
output_t(i, j, l) = out_t;
} else {
out_t = input_t(i, vx_w[l], j) * vd_e[l] +
input_t(i, vx_e[l], j) * vd_w[l];
output_t(i, l, j) = out_t;
}
}
}
}
}
template <typename T>
static void LinearInterpolationGrad(const Tensor& output_grad,
Tensor* input_grad, const float ratio_w,
const int in_w, const int n, const int c,
const int out_w, const bool align_corners,
const int align_mode,
const DataLayout data_layout) {
auto input_grad_t = EigenTensor<T, 3>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 3>::From(output_grad);
bool align_flag = (align_mode == 0 && !align_corners);
for (int l = 0; l < out_w; l++) {
int x_w = align_flag ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
: static_cast<int>(ratio_w * l);
x_w = (x_w > 0) ? x_w : 0; // w
int x_e = (x_w < (in_w - 1)) ? (x_w + 1) : x_w; // w_id
float idx_src_x = ratio_w * (l + 0.5) - 0.5;
idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w; // w1lambda
float d_e = 1.f - d_w; // w2lambda
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
// linear interpolation grad
if (data_layout == DataLayout::kNCHW) {
const T grad = output_grad_t(i, j, l);
input_grad_t(i, j, x_w) += static_cast<T>(grad * d_e);
input_grad_t(i, j, x_e) += static_cast<T>(grad * d_w);
} else {
const T grad = output_grad_t(i, l, j);
input_grad_t(i, x_w, j) += static_cast<T>(grad * d_e);
input_grad_t(i, x_e, j) += static_cast<T>(grad * d_w);
}
}
}
}
}
template <typename T>
static void BilinearInterpolation(const Tensor& input, Tensor* output,
const float ratio_h, const float ratio_w,
const int in_h, const int in_w, const int n,
const int c, const int out_h, const int out_w,
const bool align_corners,
const bool align_mode,
const DataLayout data_layout) {
auto input_t = EigenTensor<T, 4>::From(input);
auto output_t = EigenTensor<T, 4>::From(*output);
bool align_flag = (align_mode == 0 && !align_corners);
std::vector<int> vy_n, vy_s;
std::vector<float> vd_n, vd_s;
vy_n.reserve(out_h);
vy_s.reserve(out_h);
vd_n.reserve(out_h);
vd_s.reserve(out_h);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int k = 0; k < out_h; k++) {
int y_n = align_flag ? static_cast<int>(ratio_h * (k + 0.5) - 0.5)
: static_cast<int>(ratio_h * k);
y_n = (y_n > 0) ? y_n : 0;
int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
float idx_src_y = ratio_h * (k + 0.5) - 0.5;
idx_src_y = (idx_src_y > 0) ? idx_src_y : 0;
float d_n = align_flag ? idx_src_y - y_n : ratio_h * k - y_n;
float d_s = 1.f - d_n;
{
vy_n[k] = y_n;
vy_s[k] = y_s;
vd_n[k] = d_n;
vd_s[k] = d_s;
}
}
std::vector<int> vx_w, vx_e;
std::vector<float> vd_w, vd_e;
vx_w.reserve(out_w);
vx_e.reserve(out_w);
vd_w.reserve(out_w);
vd_e.reserve(out_w);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int l = 0; l < out_w; l++) {
int x_w = (align_mode == 0 && !align_corners)
? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
: static_cast<int>(ratio_w * l);
x_w = (x_w > 0) ? x_w : 0;
int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
float idx_src_x = ratio_w * (l + 0.5) - 0.5;
idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w;
float d_e = 1.f - d_w;
{
vx_w[l] = x_w;
vx_e[l] = x_e;
vd_w[l] = d_w;
vd_e[l] = d_e;
}
}
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(4)
#endif
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
for (int k = 0; k < out_h; k++) { // loop for images
for (int l = 0; l < out_w; l++) {
// bilinear interpolation
T out_t;
if (data_layout == DataLayout::kNCHW) {
out_t = input_t(i, j, vy_n[k], vx_w[l]) * vd_s[k] * vd_e[l] +
input_t(i, j, vy_s[k], vx_w[l]) * vd_n[k] * vd_e[l] +
input_t(i, j, vy_n[k], vx_e[l]) * vd_s[k] * vd_w[l] +
input_t(i, j, vy_s[k], vx_e[l]) * vd_n[k] * vd_w[l];
output_t(i, j, k, l) = out_t;
} else {
out_t = input_t(i, vy_n[k], vx_w[l], j) * vd_s[k] * vd_e[l] +
input_t(i, vy_s[k], vx_w[l], j) * vd_n[k] * vd_e[l] +
input_t(i, vy_n[k], vx_e[l], j) * vd_s[k] * vd_w[l] +
input_t(i, vy_s[k], vx_e[l], j) * vd_n[k] * vd_w[l];
output_t(i, k, l, j) = out_t;
}
}
}
}
}
}
template <typename T>
static void TrilinearInterpolation(
const Tensor& input, Tensor* output, const float ratio_d,
const float ratio_h, const float ratio_w, const int in_d, const int in_h,
const int in_w, const int n, const int c, const int out_d, const int out_h,
const int out_w, const bool align_corners, const bool align_mode,
const DataLayout& data_layout) {
auto input_t = EigenTensor<T, 5>::From(input);
auto output_t = EigenTensor<T, 5>::From(*output);
bool align_flag = (align_mode == 0 && !align_corners);
std::vector<int> vt_f, vt_b;
std::vector<float> vd_f, vd_b;
vt_f.reserve(out_d);
vt_b.reserve(out_d);
vd_f.reserve(out_d);
vd_b.reserve(out_d);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int j = 0; j < out_d; j++) {
int t_f = align_flag ? static_cast<int>(ratio_d * (j + 0.5) - 0.5)
: static_cast<int>(ratio_d * j);
t_f = (t_f > 0) ? t_f : 0;
int t_b = (t_f + 1) < (in_d - 1) ? (t_f + 1) : (in_d - 1);
float idx_src_t = ratio_d * (j + 0.5) - 0.5;
idx_src_t = (idx_src_t > 0) ? idx_src_t : 0;
float d_f = align_flag ? idx_src_t - t_f : ratio_d * j - t_f;
float d_b = 1.f - d_f;
{
vt_f[j] = t_f;
vt_b[j] = t_b;
vd_f[j] = d_f;
vd_b[j] = d_b;
}
}
std::vector<int> vy_n, vy_s;
std::vector<float> vd_n, vd_s;
vy_n.reserve(out_h);
vy_s.reserve(out_h);
vd_n.reserve(out_h);
vd_s.reserve(out_h);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int k = 0; k < out_h; k++) {
int y_n = align_flag ? static_cast<int>(ratio_h * (k + 0.5) - 0.5)
: static_cast<int>(ratio_h * k);
y_n = (y_n > 0) ? y_n : 0;
int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
float idx_src_y = ratio_h * (k + 0.5) - 0.5;
idx_src_y = (idx_src_y > 0) ? idx_src_y : 0;
float d_n = align_flag ? idx_src_y - y_n : ratio_h * k - y_n;
float d_s = 1.f - d_n;
{
vy_n[k] = y_n;
vy_s[k] = y_s;
vd_n[k] = d_n;
vd_s[k] = d_s;
}
}
std::vector<int> vx_w, vx_e;
std::vector<float> vd_w, vd_e;
vx_w.reserve(out_w);
vx_e.reserve(out_w);
vd_w.reserve(out_w);
vd_e.reserve(out_w);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int l = 0; l < out_w; l++) {
int x_w = (align_mode == 0 && !align_corners)
? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
: static_cast<int>(ratio_w * l);
x_w = (x_w > 0) ? x_w : 0;
int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
float idx_src_x = ratio_w * (l + 0.5) - 0.5;
idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w;
float d_e = 1.f - d_w;
{
vx_w[l] = x_w;
vx_e[l] = x_e;
vd_w[l] = d_w;
vd_e[l] = d_e;
}
}
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(5)
#endif
for (int b = 0; b < n; b++) { // loop for batches
for (int i = 0; i < c; i++) { // loop for channels
for (int j = 0; j < out_d; j++) { // loop for D, H, W
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
// trilinear interpolation
if (data_layout == DataLayout::kNCHW) {
T out_t = input_t(b, i, vt_f[j], vy_n[k], vx_w[l]) * vd_b[j] *
vd_s[k] * vd_e[l] +
input_t(b, i, vt_f[j], vy_n[k], vx_e[l]) * vd_b[j] *
vd_s[k] * vd_w[l] +
input_t(b, i, vt_f[j], vy_s[k], vx_w[l]) * vd_b[j] *
vd_n[k] * vd_e[l] +
input_t(b, i, vt_f[j], vy_s[k], vx_e[l]) * vd_b[j] *
vd_n[k] * vd_w[l] +
input_t(b, i, vt_b[j], vy_n[k], vx_w[l]) * vd_f[j] *
vd_s[k] * vd_e[l] +
input_t(b, i, vt_b[j], vy_n[k], vx_e[l]) * vd_f[j] *
vd_s[k] * vd_w[l] +
input_t(b, i, vt_b[j], vy_s[k], vx_w[l]) * vd_f[j] *
vd_n[k] * vd_e[l] +
input_t(b, i, vt_b[j], vy_s[k], vx_e[l]) * vd_f[j] *
vd_n[k] * vd_w[l];
output_t(b, i, j, k, l) = out_t;
} else {
T out_t = input_t(b, vt_f[j], vy_n[k], vx_w[l], i) * vd_b[j] *
vd_s[k] * vd_e[l] +
input_t(b, vt_f[j], vy_n[k], vx_e[l], i) * vd_b[j] *
vd_s[k] * vd_w[l] +
input_t(b, vt_f[j], vy_s[k], vx_w[l], i) * vd_b[j] *
vd_n[k] * vd_e[l] +
input_t(b, vt_f[j], vy_s[k], vx_e[l], i) * vd_b[j] *
vd_n[k] * vd_w[l] +
input_t(b, vt_b[j], vy_n[k], vx_w[l], i) * vd_f[j] *
vd_s[k] * vd_e[l] +
input_t(b, vt_b[j], vy_n[k], vx_e[l], i) * vd_f[j] *
vd_s[k] * vd_w[l] +
input_t(b, vt_b[j], vy_s[k], vx_w[l], i) * vd_f[j] *
vd_n[k] * vd_e[l] +
input_t(b, vt_b[j], vy_s[k], vx_e[l], i) * vd_f[j] *
vd_n[k] * vd_w[l];
output_t(b, j, k, l, i) = out_t;
}
}
}
}
}
}
}
template <typename T>
HOSTDEVICE inline T cubic_convolution1(T x, T A) {
return ((A + 2) * x - (A + 3)) * x * x + 1;
}
template <typename T>
HOSTDEVICE inline T cubic_convolution2(T x, T A) {
return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
}
template <typename T>
HOSTDEVICE inline void get_cubic_upsample_coefficients(T coeffs[4], T t) {
T A = -0.75;
T x1 = t;
coeffs[0] = cubic_convolution2<T>(x1 + 1.0, A);
coeffs[1] = cubic_convolution1<T>(x1, A);
// opposite coefficients
T x2 = 1.0 - t;
coeffs[2] = cubic_convolution1<T>(x2, A);
coeffs[3] = cubic_convolution2<T>(x2 + 1.0, A);
}
template <typename T>
static inline T cubic_interp(T x0, T x1, T x2, T x3, T t) {
T coeffs[4];
get_cubic_upsample_coefficients<T>(coeffs, t);
return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
}
template <typename T>
static void BicubicInterpolation(const Tensor& input, Tensor* output,
const float ratio_h, const float ratio_w,
const int in_h, const int in_w, const int n,
const int c, const int out_h, const int out_w,
const bool align_corners,
const DataLayout data_layout) {
auto input_t = EigenTensor<T, 4>::From(input);
auto output_t = EigenTensor<T, 4>::From(*output);
for (int k = 0; k < out_h; k++) { // loop for images
T y_n = align_corners ? static_cast<T>(ratio_h * k)
: static_cast<T>(ratio_h * (k + 0.5) - 0.5);
int input_y = floorf(y_n);
const T y_t = y_n - input_y;
for (int l = 0; l < out_w; l++) {
T x_n = align_corners ? static_cast<T>(ratio_w * l)
: static_cast<T>(ratio_w * (l + 0.5) - 0.5);
int input_x = floorf(x_n);
const T x_t = x_n - input_x;
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
T coefficients[4];
// interp 4 times in x direction
for (int ii = 0; ii < 4; ii++) {
int access_y = std::max(std::min(input_y - 1 + ii, in_h - 1),
static_cast<int>(0));
int access_x_0 =
std::max(std::min(input_x - 1, in_w - 1), static_cast<int>(0));
int access_x_1 =
std::max(std::min(input_x + 0, in_w - 1), static_cast<int>(0));
int access_x_2 =
std::max(std::min(input_x + 1, in_w - 1), static_cast<int>(0));
int access_x_3 =
std::max(std::min(input_x + 2, in_w - 1), static_cast<int>(0));
if (data_layout == DataLayout::kNCHW) {
coefficients[ii] =
cubic_interp<T>(input_t(i, j, access_y, access_x_0),
input_t(i, j, access_y, access_x_1),
input_t(i, j, access_y, access_x_2),
input_t(i, j, access_y, access_x_3), x_t);
} else {
coefficients[ii] =
cubic_interp<T>(input_t(i, access_y, access_x_0, j),
input_t(i, access_y, access_x_1, j),
input_t(i, access_y, access_x_2, j),
input_t(i, access_y, access_x_3, j), x_t);
}
}
// interp y direction
if (data_layout == DataLayout::kNCHW) {
output_t(i, j, k, l) =
cubic_interp<T>(coefficients[0], coefficients[1],
coefficients[2], coefficients[3], y_t);
} else {
output_t(i, k, l, j) =
cubic_interp<T>(coefficients[0], coefficients[1],
coefficients[2], coefficients[3], y_t);
}
}
}
}
}
}
template <typename T>
static void NearestNeighborInterpolateGrad(
const Tensor& output_grad, Tensor* input_grad, const float ratio_h,
const float ratio_w, const int n, const int c, const int out_h,
const int out_w, const bool align_corners, const DataLayout data_layout) {
auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
for (int k = 0; k < out_h; k++) { // loop for images
int in_k = (align_corners) ? static_cast<int>(ratio_h * k + 0.5)
: static_cast<int>(ratio_h * k);
for (int l = 0; l < out_w; l++) {
int in_l = (align_corners) ? static_cast<int>(ratio_w * l + 0.5)
: static_cast<int>(ratio_w * l);
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
if (data_layout == DataLayout::kNCHW) {
input_grad_t(i, j, in_k, in_l) += output_grad_t(i, j, k, l);
} else {
input_grad_t(i, in_k, in_l, j) += output_grad_t(i, k, l, j);
}
}
}
}
}
}
template <typename T>
static void BilinearInterpolationGrad(
const Tensor& output_grad, Tensor* input_grad, const float ratio_h,
const float ratio_w, const int in_h, const int in_w, const int n,
const int c, const int out_h, const int out_w, const bool align_corners,
const int align_mode, const DataLayout data_layout) {
auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
bool align_flag = (align_mode == 0 && !align_corners);
for (int k = 0; k < out_h; k++) { // loop for images
int y_n = align_flag ? static_cast<int>(ratio_h * (k + 0.5) - 0.5)
: static_cast<int>(ratio_h * k);
y_n = (y_n > 0) ? y_n : 0;
int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
float idx_src_y = ratio_h * (k + 0.5) - 0.5;
idx_src_y = (idx_src_y > 0) ? idx_src_y : 0;
float d_n = align_flag ? idx_src_y - y_n : ratio_h * k - y_n;
float d_s = 1.f - d_n;
for (int l = 0; l < out_w; l++) {
int x_w = align_flag ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
: static_cast<int>(ratio_w * l);
x_w = (x_w > 0) ? x_w : 0;
int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
float idx_src_x = ratio_w * (l + 0.5) - 0.5;
idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w;
float d_e = 1.f - d_w;
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
// bilinear interpolation grad
if (data_layout == DataLayout::kNCHW) {
const T grad = output_grad_t(i, j, k, l);
input_grad_t(i, j, y_n, x_w) += static_cast<T>(grad * d_s * d_e);
input_grad_t(i, j, y_s, x_w) += static_cast<T>(grad * d_n * d_e);
input_grad_t(i, j, y_n, x_e) += static_cast<T>(grad * d_s * d_w);
input_grad_t(i, j, y_s, x_e) += static_cast<T>(grad * d_n * d_w);
} else {
const T grad = output_grad_t(i, k, l, j);
input_grad_t(i, y_n, x_w, j) += static_cast<T>(grad * d_s * d_e);
input_grad_t(i, y_s, x_w, j) += static_cast<T>(grad * d_n * d_e);
input_grad_t(i, y_n, x_e, j) += static_cast<T>(grad * d_s * d_w);
input_grad_t(i, y_s, x_e, j) += static_cast<T>(grad * d_n * d_w);
}
}
}
}
}
}
template <typename T>
static void TrilinearInterpolationGrad(
const Tensor& output_grad, Tensor* input_grad, const float ratio_d,
const float ratio_h, const float ratio_w, const int in_d, const int in_h,
const int in_w, const int n, const int c, const int out_d, const int out_h,
const int out_w, const bool align_corners, const int align_mode,
const DataLayout data_layout) {
auto input_grad_t = EigenTensor<T, 5>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 5>::From(output_grad);
bool align_flag = (align_mode == 0 && !align_corners);
for (int j = 0; j < out_d; j++) { // loop for D
int t_f = align_flag ? static_cast<int>(ratio_d * (j + 0.5) - 0.5)
: static_cast<int>(ratio_d * j);
t_f = (t_f > 0) ? t_f : 0;
int t_b = (t_f + 1) < (in_d - 1) ? (t_f + 1) : (in_d - 1);
float idx_src_t = ratio_d * (j + 0.5) - 0.5;
idx_src_t = (idx_src_t > 0) ? idx_src_t : 0;
float d_f = align_flag ? idx_src_t - t_f : ratio_d * j - t_f;
float d_b = 1.f - d_f;
for (int k = 0; k < out_h; k++) { // loop for H
int y_n = align_flag ? static_cast<int>(ratio_h * (k + 0.5) - 0.5)
: static_cast<int>(ratio_h * k);
y_n = (y_n > 0) ? y_n : 0;
int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
float idx_src_y = ratio_h * (k + 0.5) - 0.5;
idx_src_y = (idx_src_y > 0) ? idx_src_y : 0;
float d_n = align_flag ? idx_src_y - y_n : ratio_h * k - y_n;
float d_s = 1.f - d_n;
for (int l = 0; l < out_w; l++) { // loop for W
int x_w = align_flag ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
: static_cast<int>(ratio_w * l);
x_w = (x_w > 0) ? x_w : 0;
int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
float idx_src_x = ratio_w * (l + 0.5) - 0.5;
idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w;
float d_e = 1.f - d_w;
for (int b = 0; b < n; b++) { // loop for batches
for (int i = 0; i < c; i++) { // loop for channels
// trilinear interpolation grad
if (data_layout == DataLayout::kNCHW) {
const T grad = output_grad_t(b, i, j, k, l);
input_grad_t(b, i, t_f, y_n, x_w) +=
static_cast<T>(grad * d_b * d_s * d_e);
input_grad_t(b, i, t_f, y_n, x_e) +=
static_cast<T>(grad * d_b * d_s * d_w);
input_grad_t(b, i, t_f, y_s, x_w) +=
static_cast<T>(grad * d_b * d_n * d_e);
input_grad_t(b, i, t_f, y_s, x_e) +=
static_cast<T>(grad * d_b * d_n * d_w);
input_grad_t(b, i, t_b, y_n, x_w) +=
static_cast<T>(grad * d_f * d_s * d_e);
input_grad_t(b, i, t_b, y_n, x_e) +=
static_cast<T>(grad * d_f * d_s * d_w);
input_grad_t(b, i, t_b, y_s, x_w) +=
static_cast<T>(grad * d_f * d_n * d_e);
input_grad_t(b, i, t_b, y_s, x_e) +=
static_cast<T>(grad * d_f * d_n * d_w);
} else {
const T grad = output_grad_t(b, j, k, l, i);
input_grad_t(b, t_f, y_n, x_w, i) +=
static_cast<T>(grad * d_b * d_s * d_e);
input_grad_t(b, t_f, y_n, x_e, i) +=
static_cast<T>(grad * d_b * d_s * d_w);
input_grad_t(b, t_f, y_s, x_w, i) +=
static_cast<T>(grad * d_b * d_n * d_e);
input_grad_t(b, t_f, y_s, x_e, i) +=
static_cast<T>(grad * d_b * d_n * d_w);
input_grad_t(b, t_b, y_n, x_w, i) +=
static_cast<T>(grad * d_f * d_s * d_e);
input_grad_t(b, t_b, y_n, x_e, i) +=
static_cast<T>(grad * d_f * d_s * d_w);
input_grad_t(b, t_b, y_s, x_w, i) +=
static_cast<T>(grad * d_f * d_n * d_e);
input_grad_t(b, t_b, y_s, x_e, i) +=
static_cast<T>(grad * d_f * d_n * d_w);
}
}
}
}
}
}
}
template <typename T>
static void BicubicInterpolationGrad(const Tensor& output_grad,
Tensor* input_grad, const float ratio_h,
const float ratio_w, const int in_h,
const int in_w, const int n, const int c,
const int out_h, const int out_w,
const bool align_corners,
const DataLayout data_layout) {
auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
for (int k = 0; k < out_h; k++) { // loop for images
T y_n = align_corners ? static_cast<T>(ratio_h * k)
: static_cast<T>(ratio_h * (k + 0.5) - 0.5);
int input_y = floorf(y_n);
T y_t = y_n - input_y;
for (int l = 0; l < out_w; l++) {
T x_n = align_corners ? static_cast<T>(ratio_w * l)
: static_cast<T>(ratio_w * (l + 0.5) - 0.5);
int input_x = floorf(x_n);
T x_t = x_n - input_x;
T x_coeffs[4];
T y_coeffs[4];
get_cubic_upsample_coefficients<T>(x_coeffs, x_t);
get_cubic_upsample_coefficients<T>(y_coeffs, y_t);
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
// bicubic interpolation grad
for (int ii = 0; ii < 4; ii++) {
for (int jj = 0; jj < 4; jj++) {
int access_x = std::max(std::min(input_x - 1 + ii, in_w - 1),
static_cast<int>(0));
int access_y = std::max(std::min(input_y - 1 + jj, in_h - 1),
static_cast<int>(0));
if (data_layout == DataLayout::kNCHW) {
T grad = output_grad_t(i, j, k, l);
input_grad_t(i, j, access_y, access_x) +=
grad * y_coeffs[jj] * x_coeffs[ii];
} else {
T grad = output_grad_t(i, k, l, j);
input_grad_t(i, access_y, access_x, j) +=
grad * y_coeffs[jj] * x_coeffs[ii];
}
}
}
}
}
}
}
}
template <typename T>
static void Interpolate1DCPUFwd(const framework::ExecutionContext& ctx,
const Tensor& input, Tensor* output) {
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_w = ctx.Attr<int>("out_w");
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_size_tensor);
out_w = new_size[0];
} else {
float scale_w = -1;
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
scale_w = scale_data[0];
PADDLE_ENFORCE_EQ(scale_w > 0, true, platform::errors::InvalidArgument(
"scale of Op(interpolate) "
"should be greater than 0."));
} else {
if (scale.size() > 0) {
scale_w = scale[0];
PADDLE_ENFORCE_EQ(scale_w > 0, true, platform::errors::InvalidArgument(
"scale of Op(interpolate) "
"should be greater than 0."));
}
}
if (scale_w > 0.) {
out_w = static_cast<int>(in_w * scale_w);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
auto out_size_data = get_new_data_from_tensor<int>(out_size);
out_w = out_size_data[0];
}
}
PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument(
"out_w in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {n, c, out_w};
} else {
dim_out = {n, out_w, c};
}
output->mutable_data<T>(dim_out, ctx.GetPlace());
if (in_w == out_w) {
framework::TensorCopy(input, ctx.GetPlace(), output);
return;
}
float ratio_w = 0.f;
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
if ("linear" == interp_method) {
LinearInterpolation<T>(input, output, ratio_w, in_w, n, c, out_w,
align_corners, align_mode, data_layout);
}
}
template <typename T>
static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx,
const Tensor& input, Tensor* output) {
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_size_tensor);
out_h = new_size[0];
out_w = new_size[1];
} else {
float scale_h = -1;
float scale_w = -1;
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
if (scale_data.size() > 1) {
scale_h = scale_data[0];
scale_w = scale_data[1];
} else {
scale_h = scale_data[0];
scale_w = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
} else {
if (scale.size() > 1) {
scale_h = scale[0];
scale_w = scale[1];
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
}
}
if (scale_h > 0. && scale_w > 0.) {
out_h = static_cast<int>(in_h * scale_h);
out_w = static_cast<int>(in_w * scale_w);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
auto out_size_data = get_new_data_from_tensor<int>(out_size);
out_h = out_size_data[0];
out_w = out_size_data[1];
}
}
PADDLE_ENFORCE_GT(out_h, 0, platform::errors::InvalidArgument(
"out_h in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument(
"out_w in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {n, c, out_h, out_w};
} else {
dim_out = {n, out_h, out_w, c};
}
output->mutable_data<T>(dim_out, ctx.GetPlace());
if (in_h == out_h && in_w == out_w) {
framework::TensorCopy(input, ctx.GetPlace(), output);
return;
}
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
}
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
if ("bilinear" == interp_method) {
BilinearInterpolation<T>(input, output, ratio_h, ratio_w, in_h, in_w, n, c,
out_h, out_w, align_corners, align_mode,
data_layout);
} else if ("nearest" == interp_method) {
NearestNeighborInterpolate<T>(input, output, ratio_h, ratio_w, n, c, out_h,
out_w, align_corners, data_layout);
} else if ("bicubic" == interp_method) {
BicubicInterpolation<T>(input, output, ratio_h, ratio_w, in_h, in_w, n, c,
out_h, out_w, align_corners, data_layout);
}
}
template <typename T>
static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx,
const Tensor& input, Tensor* output) {
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_d = ctx.Attr<int>("out_d");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_size_tensor);
out_d = new_size[0];
out_h = new_size[1];
out_w = new_size[2];
} else {
float scale_d = -1;
float scale_h = -1;
float scale_w = -1;
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
if (scale_data.size() > 1) {
scale_d = scale_data[0];
scale_h = scale_data[1];
scale_w = scale_data[2];
} else {
scale_d = scale_data[0];
scale_h = scale_data[0];
scale_w = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0 && scale_d, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
} else {
if (scale.size() > 1) {
scale_d = scale[0];
scale_h = scale[1];
scale_w = scale[2];
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0 && scale_d, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
}
}
if (scale_w > 0. && scale_h > 0. && scale_d > 0.) {
out_d = static_cast<int>(in_d * scale_d);
out_h = static_cast<int>(in_h * scale_h);
out_w = static_cast<int>(in_w * scale_w);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
auto out_size_data = get_new_data_from_tensor<int>(out_size);
out_d = out_size_data[0];
out_h = out_size_data[1];
out_w = out_size_data[2];
}
}
PADDLE_ENFORCE_GT(out_d, 0, platform::errors::InvalidArgument(
"out_d in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
PADDLE_ENFORCE_GT(out_h, 0, platform::errors::InvalidArgument(
"out_h in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument(
"out_w in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {n, c, out_d, out_h, out_w};
} else {
dim_out = {n, out_d, out_h, out_w, c};
}
output->mutable_data<T>(dim_out, ctx.GetPlace());
if (in_d == out_d && in_h == out_h && in_w == out_w) {
framework::TensorCopy(input, ctx.GetPlace(), output);
return;
}
float ratio_d = 0.f;
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_d > 1) {
ratio_d = (align_corners) ? static_cast<float>(in_d - 1) / (out_d - 1)
: static_cast<float>(in_d) / out_d;
}
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
}
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
if ("trilinear" == interp_method) {
TrilinearInterpolation<T>(input, output, ratio_d, ratio_h, ratio_w, in_d,
in_h, in_w, n, c, out_d, out_h, out_w,
align_corners, align_mode, data_layout);
}
}
template <typename T>
static void Interpolate1DCPUBwd(const framework::ExecutionContext& ctx,
Tensor* input_grad, const Tensor& output_grad) {
auto* input = ctx.Input<Tensor>("X");
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_w = ctx.Attr<int>("out_w");
float scale_w = -1.0;
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
scale_w = scale_data[0];
PADDLE_ENFORCE_EQ(scale_w > 0, true, platform::errors::InvalidArgument(
"scale of Op(interpolate) "
"should be greater than 0."));
} else {
if (scale.size() > 0) {
scale_w = scale[0];
PADDLE_ENFORCE_EQ(scale_w > 0, true, platform::errors::InvalidArgument(
"scale of Op(interpolate) "
"should be greater than 0."));
}
}
if (scale_w > 0.) {
out_w = static_cast<int>(in_w * scale_w);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
auto out_size_data = get_new_data_from_tensor<int>(out_size);
out_w = out_size_data[0];
}
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_size_tensor);
out_w = new_size[0];
}
framework::DDim dim_grad;
if (data_layout == DataLayout::kNCHW) {
dim_grad = {n, c, in_w};
} else {
dim_grad = {n, in_w, c};
}
input_grad->mutable_data<T>(dim_grad, ctx.GetPlace());
auto& device_ctx = ctx.template device_context<platform::CPUDeviceContext>();
math::SetConstant<platform::CPUDeviceContext, T> zero;
zero(device_ctx, input_grad, static_cast<T>(0.0));
if (in_w == out_w) {
framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
return;
}
float ratio_w = 0.f;
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
if ("linear" == interp_method) {
LinearInterpolationGrad<T>(output_grad, input_grad, ratio_w, in_w, n, c,
out_w, align_corners, align_mode, data_layout);
}
}
template <typename T>
static void Interpolate2DCPUBwd(const framework::ExecutionContext& ctx,
Tensor* input_grad, const Tensor& output_grad) {
auto* input = ctx.Input<Tensor>("X");
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale_h = -1;
float scale_w = -1;
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
if (scale_data.size() > 1) {
scale_h = scale_data[0];
scale_w = scale_data[1];
} else {
scale_w = scale_data[0];
scale_h = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
} else {
if (scale.size() > 1) {
scale_h = scale[0];
scale_w = scale[1];
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
}
}
if (scale_h > 0. && scale_w > 0.) {
out_h = static_cast<int>(in_h * scale_h);
out_w = static_cast<int>(in_w * scale_w);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
auto out_size_data = get_new_data_from_tensor<int>(out_size);
out_h = out_size_data[0];
out_w = out_size_data[1];
}
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_size_tensor);
out_h = new_size[0];
out_w = new_size[1];
}
framework::DDim dim_grad;
if (data_layout == DataLayout::kNCHW) {
dim_grad = {n, c, in_h, in_w};
} else {
dim_grad = {n, in_h, in_w, c};
}
input_grad->mutable_data<T>(dim_grad, ctx.GetPlace());
auto& device_ctx = ctx.template device_context<platform::CPUDeviceContext>();
math::SetConstant<platform::CPUDeviceContext, T> zero;
zero(device_ctx, input_grad, static_cast<T>(0.0));
if (in_h == out_h && in_w == out_w) {
framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
return;
}
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
}
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
if ("bilinear" == interp_method) {
BilinearInterpolationGrad<T>(output_grad, input_grad, ratio_h, ratio_w,
in_h, in_w, n, c, out_h, out_w, align_corners,
align_mode, data_layout);
} else if ("nearest" == interp_method) {
NearestNeighborInterpolateGrad<T>(output_grad, input_grad, ratio_h, ratio_w,
n, c, out_h, out_w, align_corners,
data_layout);
} else if ("bicubic" == interp_method) {
BicubicInterpolationGrad<T>(output_grad, input_grad, ratio_h, ratio_w, in_h,
in_w, n, c, out_h, out_w, align_corners,
data_layout);
}
}
template <typename T>
static void Interpolate3DCPUBwd(const framework::ExecutionContext& ctx,
Tensor* input_grad, const Tensor output_grad) {
auto* input = ctx.Input<Tensor>("X");
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_d = ctx.Attr<int>("out_d");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale_d = -1;
float scale_h = -1;
float scale_w = -1;
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
if (scale_data.size() > 1) {
scale_d = scale_data[0];
scale_h = scale_data[1];
scale_w = scale_data[2];
} else {
scale_d = scale_data[0];
scale_h = scale_data[0];
scale_w = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0 && scale_d > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
} else {
if (scale.size() > 1) {
scale_d = scale[0];
scale_h = scale[1];
scale_w = scale[2];
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0 && scale_d > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
}
}
if (scale_d > 0. && scale_h > 0. && scale_w > 0.) {
out_d = static_cast<int>(in_d * scale_d);
out_h = static_cast<int>(in_h * scale_h);
out_w = static_cast<int>(in_w * scale_w);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
auto out_size_data = get_new_data_from_tensor<int>(out_size);
out_d = out_size_data[0];
out_h = out_size_data[1];
out_w = out_size_data[2];
}
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_size_tensor);
out_d = new_size[0];
out_h = new_size[1];
out_w = new_size[2];
}
framework::DDim dim_grad;
if (data_layout == DataLayout::kNCHW) {
dim_grad = {n, c, in_d, in_h, in_w};
} else {
dim_grad = {n, in_d, in_h, in_w, c};
}
input_grad->mutable_data<T>(dim_grad, ctx.GetPlace());
auto& device_ctx = ctx.template device_context<platform::CPUDeviceContext>();
math::SetConstant<platform::CPUDeviceContext, T> zero;
zero(device_ctx, input_grad, static_cast<T>(0.0));
if (in_d == out_d && in_h == out_h && in_w == out_w) {
framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
return;
}
float ratio_d = 0.f;
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_d > 1) {
ratio_d = (align_corners) ? static_cast<float>(in_d - 1) / (out_d - 1)
: static_cast<float>(in_d) / out_d;
}
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
}
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
if ("trilinear" == interp_method) {
TrilinearInterpolationGrad<T>(
output_grad, input_grad, ratio_d, ratio_h, ratio_w, in_d, in_h, in_w, n,
c, out_d, out_h, out_w, align_corners, align_mode, data_layout);
}
}
template <typename T>
class InterpolateV2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
auto input_dims = input->dims();
if (input_dims.size() == 3) { // 1D interpolation
Interpolate1DCPUFwd<T>(ctx, *input, output);
} else if (input_dims.size() == 4) { // 2D interpolation
Interpolate2DCPUFwd<T>(ctx, *input, output);
} else if (input_dims.size() == 5) { // 3D interpolation
Interpolate3DCPUFwd<T>(ctx, *input, output);
}
}
};
template <typename T>
class InterpolateV2GradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto output_grad_dims = output_grad->dims();
if (output_grad_dims.size() == 3) { // 1D interpolation grad
Interpolate1DCPUBwd<T>(ctx, input_grad, *output_grad);
} else if (output_grad_dims.size() == 4) { // 2D interpolation grad
Interpolate2DCPUBwd<T>(ctx, input_grad, *output_grad);
} else if (output_grad_dims.size() == 5) { // 3D interpolation grad
Interpolate3DCPUBwd<T>(ctx, input_grad, *output_grad);
}
}
};
} // namespace operators
} // namespace paddle
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
import paddle.fluid as fluid
import paddle
from paddle.fluid import Program, program_guard
from paddle.nn.functional import interpolate
def cubic_1(x, a):
return ((a + 2) * x - (a + 3)) * x * x + 1
def cubic_2(x, a):
return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a
def cubic_interp1d(x0, x1, x2, x3, t):
param = [0, 0, 0, 0]
a = -0.75
x_1 = t
x_2 = 1.0 - t
param[0] = cubic_2(x_1 + 1.0, a)
param[1] = cubic_1(x_1, a)
param[2] = cubic_1(x_2, a)
param[3] = cubic_2(x_2 + 1.0, a)
return x0 * param[0] + x1 * param[1] + x2 * param[2] + x3 * param[3]
def value_bound(input, w, h, x, y):
access_x = int(max(min(x, w - 1), 0))
access_y = int(max(min(y, h - 1), 0))
return input[:, :, access_y, access_x]
def bicubic_interp_np(input,
out_h,
out_w,
out_size=None,
actual_shape=None,
align_corners=True,
data_layout='kNCHW'):
"""trilinear interpolation implement in shape [N, C, H, W]"""
if data_layout == "NHWC":
input = np.transpose(input, (0, 3, 1, 2)) # NHWC => NCHW
if out_size is not None:
out_h = out_size[0]
out_w = out_size[1]
if actual_shape is not None:
out_h = actual_shape[0]
out_w = actual_shape[1]
batch_size, channel, in_h, in_w = input.shape
ratio_h = ratio_w = 0.0
if out_h > 1:
if (align_corners):
ratio_h = (in_h - 1.0) / (out_h - 1.0)
else:
ratio_h = 1.0 * in_h / out_h
if out_w > 1:
if (align_corners):
ratio_w = (in_w - 1.0) / (out_w - 1.0)
else:
ratio_w = 1.0 * in_w / out_w
out = np.zeros((batch_size, channel, out_h, out_w))
for k in range(out_h):
if (align_corners):
h = ratio_h * k
else:
h = ratio_h * (k + 0.5) - 0.5
input_y = np.floor(h)
y_t = h - input_y
for l in range(out_w):
if (align_corners):
w = ratio_w * l
else:
w = ratio_w * (l + 0.5) - 0.5
input_x = np.floor(w)
x_t = w - input_x
for i in range(batch_size):
for j in range(channel):
coefficients = [0, 0, 0, 0]
for ii in range(4):
access_x_0 = int(max(min(input_x - 1, in_w - 1), 0))
access_x_1 = int(max(min(input_x + 0, in_w - 1), 0))
access_x_2 = int(max(min(input_x + 1, in_w - 1), 0))
access_x_3 = int(max(min(input_x + 2, in_w - 1), 0))
access_y = int(max(min(input_y - 1 + ii, in_h - 1), 0))
coefficients[ii] = cubic_interp1d(
input[i, j, access_y, access_x_0],
input[i, j, access_y, access_x_1],
input[i, j, access_y, access_x_2],
input[i, j, access_y, access_x_3], x_t)
out[i, j, k, l] = cubic_interp1d(
coefficients[0], coefficients[1], coefficients[2],
coefficients[3], y_t)
if data_layout == "NHWC":
out = np.transpose(out, (0, 2, 3, 1)) # NCHW => NHWC
return out.astype(input.dtype)
class TestBicubicInterpOp(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.data_layout = 'NCHW'
self.init_test_case()
self.op_type = "bicubic_interp_v2"
input_np = np.random.random(self.input_shape).astype("float64")
if self.data_layout == "NCHW":
in_h = self.input_shape[2]
in_w = self.input_shape[3]
else:
in_h = self.input_shape[1]
in_w = self.input_shape[2]
if self.scale:
if isinstance(self.scale, float) or isinstance(self.scale, int):
if self.scale > 0.:
scale_h = scale_w = float(self.scale)
if isinstance(self.scale, list) and len(self.scale) == 1:
scale_w = scale_h = self.scale[0]
elif isinstance(self.scale, list) and len(self.scale) > 1:
scale_w = self.scale[1]
scale_h = self.scale[0]
out_h = int(in_h * scale_h)
out_w = int(in_w * scale_w)
else:
out_h = self.out_h
out_w = self.out_w
output_np = bicubic_interp_np(input_np, out_h, out_w, self.out_size,
self.actual_shape, self.align_corners,
self.data_layout)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
if self.actual_shape is not None:
self.inputs['OutSize'] = self.actual_shape
self.attrs = {
'out_h': self.out_h,
'out_w': self.out_w,
'interp_method': self.interp_method,
'align_corners': self.align_corners,
'data_layout': self.data_layout
}
if self.scale:
if isinstance(self.scale, float) or isinstance(self.scale, int):
if self.scale > 0.:
self.scale = [self.scale]
if isinstance(self.scale, list) and len(self.scale) == 1:
self.scale = [self.scale[0], self.scale[0]]
self.attrs['scale'] = self.scale
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', in_place=True)
def init_test_case(self):
self.interp_method = 'bicubic'
self.input_shape = [2, 3, 5, 5]
self.out_h = 2
self.out_w = 2
self.scale = 0.
self.out_size = np.array([3, 3]).astype("int32")
self.align_corners = True
class TestBicubicInterpCase1(TestBicubicInterpOp):
def init_test_case(self):
self.interp_method = 'bicubic'
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
self.scale = 0.
self.align_corners = True
class TestBicubicInterpCase2(TestBicubicInterpOp):
def init_test_case(self):
self.interp_method = 'bicubic'
self.input_shape = [3, 3, 9, 6]
self.out_h = 10
self.out_w = 8
self.scale = 0.
self.align_corners = True
class TestBicubicInterpCase3(TestBicubicInterpOp):
def init_test_case(self):
self.interp_method = 'bicubic'
self.input_shape = [1, 1, 32, 64]
self.out_h = 64
self.out_w = 32
self.scale = 0.
self.align_corners = False
class TestBicubicInterpCase4(TestBicubicInterpOp):
def init_test_case(self):
self.interp_method = 'bicubic'
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
self.scale = 0.
self.out_size = np.array([2, 2]).astype("int32")
self.align_corners = True
class TestBicubicInterpCase5(TestBicubicInterpOp):
def init_test_case(self):
self.interp_method = 'bicubic'
self.input_shape = [3, 3, 9, 6]
self.out_h = 11
self.out_w = 11
self.scale = 0.
self.out_size = np.array([6, 4]).astype("int32")
self.align_corners = False
class TestBicubicInterpCase6(TestBicubicInterpOp):
def init_test_case(self):
self.interp_method = 'bicubic'
self.input_shape = [1, 1, 32, 64]
self.out_h = 64
self.out_w = 32
self.scale = 0
self.out_size = np.array([64, 32]).astype("int32")
self.align_corners = False
class TestBicubicInterpSame(TestBicubicInterpOp):
def init_test_case(self):
self.interp_method = 'bicubic'
self.input_shape = [2, 3, 32, 64]
self.out_h = 32
self.out_w = 64
self.scale = 0.
self.align_corners = True
class TestBicubicInterpScale(TestBicubicInterpOp):
def init_test_case(self):
self.interp_method = 'bicubic'
self.input_shape = [2, 3, 32, 64]
self.out_h = 32
self.out_w = 64
self.scale = [1., 1.]
self.align_corners = True
class TestBicubicInterpDataLayout(TestBicubicInterpOp):
def init_test_case(self):
self.interp_method = 'bicubic'
self.input_shape = [2, 5, 5, 3]
self.out_h = 2
self.out_w = 2
self.scale = 0.
self.out_size = np.array([3, 3]).astype("int32")
self.align_corners = True
self.data_layout = "NHWC"
class TestBicubicInterpOpAPI(unittest.TestCase):
def test_case(self):
np.random.seed(200)
x_data = np.random.random((2, 3, 6, 6)).astype("float32")
dim_data = np.array([12]).astype("int32")
shape_data = np.array([12, 12]).astype("int32")
actual_size_data = np.array([12, 12]).astype("int32")
scale_data = np.array([2.0]).astype("float32")
prog = fluid.Program()
startup_prog = fluid.Program()
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
with fluid.program_guard(prog, startup_prog):
x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32")
dim = fluid.data(name="dim", shape=[1], dtype="int32")
shape_tensor = fluid.data(
name="shape_tensor", shape=[2], dtype="int32")
actual_size = fluid.data(
name="actual_size", shape=[2], dtype="int32")
scale_tensor = fluid.data(
name="scale_tensor", shape=[1], dtype="float32")
out1 = interpolate(
x, size=[12, 12], mode='bicubic', align_corners=False)
out2 = interpolate(
x, size=[12, dim], mode='bicubic', align_corners=False)
out3 = interpolate(
x, size=shape_tensor, mode='bicubic', align_corners=False)
out4 = interpolate(
x, size=[12, 12], mode='bicubic', align_corners=False)
out5 = interpolate(
x,
scale_factor=scale_tensor,
mode='bicubic',
align_corners=False)
out6 = interpolate(
x, scale_factor=2.0, mode='bicubic', align_corners=False)
out7 = interpolate(
x, scale_factor=[2.0, 2.0], mode='bicubic', align_corners=False)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
results = exe.run(
fluid.default_main_program(),
feed={
"x": x_data,
"dim": dim_data,
"shape_tensor": shape_data,
"actual_size": actual_size_data,
"scale_tensor": scale_data
},
fetch_list=[out1, out2, out3, out4, out5, out6, out7],
return_numpy=True)
expect_res = bicubic_interp_np(
x_data, out_h=12, out_w=12, align_corners=False)
for res in results:
self.assertTrue(np.allclose(res, expect_res))
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(x_data)
interp = interpolate(
x, size=[12, 12], mode='bicubic', align_corners=False)
dy_result = interp.numpy()
expect = bicubic_interp_np(
x_data, out_h=12, out_w=12, align_corners=False)
self.assertTrue(np.allclose(dy_result, expect))
class TestBicubicOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# the input of interpoalte must be Variable.
x1 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace())
self.assertRaises(TypeError, interpolate, x1)
def test_mode_type():
# mode must be "BILINEAR" "TRILINEAR" "NEAREST" "BICUBIC"
x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32")
out = interpolate(
x, size=[12, 12], mode='UNKONWN', align_corners=False)
def test_input_shape():
x = fluid.data(name="x", shape=[2], dtype="float32")
out = interpolate(
x, size=[12, 12], mode='BICUBIC', align_corners=False)
def test_align_corcers():
x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32")
interpolate(x, size=[12, 12], mode='BICUBIC', align_corners=3)
def test_out_shape():
x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32")
out = interpolate(
x, size=[12], mode='bicubic', align_corners=False)
def test_attr_data_format():
# for 5-D input, data_format only can be NCDHW or NDHWC
input = fluid.data(
name="input", shape=[2, 3, 6, 9, 4], dtype="float32")
out = interpolate(
input,
size=[4, 8, 4, 5],
mode='trilinear',
data_format='NHWC')
def test_actual_shape():
# the actual_shape must be Variable.
x = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace())
out = interpolate(
x, size=[12, 12], mode='BICUBIC', align_corners=False)
def test_scale_value():
# the scale must be greater than zero.
x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32")
out = interpolate(
x,
size=None,
mode='BICUBIC',
align_corners=False,
scale_factor=-2.0)
def test_attr_5D_input():
# for 5-D input, data_format only can be NCDHW or NDHWC
input = fluid.data(
name="input", shape=[2, 3, 6, 9, 4], dtype="float32")
out = interpolate(
input,
size=[4, 8, 4, 5],
mode='trilinear',
data_format='NDHWC')
def test_scale_type():
# the scale must be greater than zero.
x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32")
scale = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace())
out = interpolate(
x,
size=None,
mode='bicubic',
align_corners=False,
scale_factor=scale)
def test_align_mode():
x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32")
out = interpolate(
x,
size=None,
mode='nearest',
align_corners=False,
align_mode=2,
scale_factor=1.0)
def test_outshape_and_scale():
x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32")
out = interpolate(
x,
size=None,
mode='bicubic',
align_corners=False,
scale_factor=None)
def test_align_corners_and_nearest():
x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32")
out = interpolate(
x,
size=None,
mode='nearest',
align_corners=True,
scale_factor=None)
def test_scale_shape():
x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32")
out = interpolate(
x,
size=None,
mode='nearest',
align_corners=False,
scale_factor=[1, 2, 2])
def test_scale_value():
x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32")
out = interpolate(
x,
size=None,
mode='trilinear',
align_corners=False,
scale_factor=[1, 2, 2])
self.assertRaises(ValueError, test_mode_type)
self.assertRaises(ValueError, test_input_shape)
self.assertRaises(TypeError, test_align_corcers)
self.assertRaises(ValueError, test_attr_data_format)
self.assertRaises(TypeError, test_actual_shape)
self.assertRaises(ValueError, test_scale_value)
self.assertRaises(ValueError, test_out_shape)
self.assertRaises(ValueError, test_attr_5D_input)
self.assertRaises(TypeError, test_scale_type)
self.assertRaises(ValueError, test_align_mode)
self.assertRaises(ValueError, test_outshape_and_scale)
self.assertRaises(ValueError, test_align_corners_and_nearest)
self.assertRaises(ValueError, test_scale_shape)
self.assertRaises(ValueError, test_scale_value)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.nn.functional import interpolate
import paddle
def bilinear_interp_np(input,
out_h,
out_w,
out_size=None,
actual_shape=None,
align_corners=True,
align_mode=0,
data_layout='NCHW'):
"""bilinear interpolation implement in shape [N, C, H, W]"""
if data_layout == "NHWC":
input = np.transpose(input, (0, 3, 1, 2)) # NHWC => NCHW
if out_size is not None:
out_h = out_size[0]
out_w = out_size[1]
if actual_shape is not None:
out_h = actual_shape[0]
out_w = actual_shape[1]
batch_size, channel, in_h, in_w = input.shape
ratio_h = ratio_w = 0.0
if out_h > 1:
if (align_corners):
ratio_h = (in_h - 1.0) / (out_h - 1.0)
else:
ratio_h = 1.0 * in_h / out_h
if out_w > 1:
if (align_corners):
ratio_w = (in_w - 1.0) / (out_w - 1.0)
else:
ratio_w = 1.0 * in_w / out_w
out = np.zeros((batch_size, channel, out_h, out_w))
for i in range(out_h):
if (align_mode == 0 and not align_corners):
h = int(ratio_h * (i + 0.5) - 0.5)
else:
h = int(ratio_h * i)
h = max(0, h)
hid = 1 if h < in_h - 1 else 0
if (align_mode == 0 and not align_corners):
idx_src_h = max(ratio_h * (i + 0.5) - 0.5, 0)
h1lambda = idx_src_h - h
else:
h1lambda = ratio_h * i - h
h2lambda = 1.0 - h1lambda
for j in range(out_w):
if (align_mode == 0 and not align_corners):
w = int(ratio_w * (j + 0.5) - 0.5)
else:
w = int(ratio_w * j)
w = max(0, w)
wid = 1 if w < in_w - 1 else 0
if (align_mode == 0 and not align_corners):
idx_src_w = max(ratio_w * (j + 0.5) - 0.5, 0)
w1lambda = idx_src_w - w
else:
w1lambda = ratio_w * j - w
w2lambda = 1.0 - w1lambda
out[:, :, i, j] = h2lambda*(w2lambda*input[:, :, h, w] +
w1lambda*input[:, :, h, w+wid]) + \
h1lambda*(w2lambda*input[:, :, h+hid, w] +
w1lambda*input[:, :, h+hid, w+wid])
if data_layout == "NHWC":
out = np.transpose(out, (0, 2, 3, 1)) # NCHW => NHWC
return out.astype(input.dtype)
class TestBilinearInterpOp(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.data_layout = 'NCHW'
self.init_test_case()
self.op_type = "bilinear_interp_v2"
input_np = np.random.random(self.input_shape).astype("float64")
if self.data_layout == "NCHW":
in_h = self.input_shape[2]
in_w = self.input_shape[3]
else:
in_h = self.input_shape[1]
in_w = self.input_shape[2]
if self.scale:
if isinstance(self.scale, float) or isinstance(self.scale, int):
if self.scale > 0.:
scale_h = scale_w = float(self.scale)
if isinstance(self.scale, list) and len(self.scale) == 1:
scale_w = scale_h = self.scale[0]
elif isinstance(self.scale, list) and len(self.scale) > 1:
scale_w = self.scale[1]
scale_h = self.scale[0]
out_h = int(in_h * scale_h)
out_w = int(in_w * scale_w)
else:
out_h = self.out_h
out_w = self.out_w
output_np = bilinear_interp_np(input_np, out_h, out_w, self.out_size,
self.actual_shape, self.align_corners,
self.align_mode, self.data_layout)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
if self.actual_shape is not None:
self.inputs['OutSize'] = self.actual_shape
self.attrs = {
'out_h': self.out_h,
'out_w': self.out_w,
'interp_method': self.interp_method,
'align_corners': self.align_corners,
'align_mode': self.align_mode,
'data_layout': self.data_layout
}
if self.scale:
if isinstance(self.scale, float) or isinstance(self.scale, int):
if self.scale > 0.:
self.scale = [self.scale]
if isinstance(self.scale, list) and len(self.scale) == 1:
self.scale = [self.scale[0], self.scale[0]]
self.attrs['scale'] = self.scale
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', in_place=True)
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 5, 5]
self.out_h = 2
self.out_w = 2
self.scale = 0.
self.out_size = np.array([3, 3]).astype("int32")
self.align_corners = True
self.align_mode = 1
class TestBilinearInterpCase1(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
self.scale = 0.
self.align_corners = True
self.align_mode = 1
class TestBilinearInterpCase2(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
self.scale = 0.
self.align_corners = True
self.align_mode = 1
class TestBilinearInterpCase3(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [1, 1, 32, 64]
self.out_h = 64
self.out_w = 32
self.scale = 0.
self.align_corners = True
self.align_mode = 1
class TestBilinearInterpCase4(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
self.scale = 0.
self.out_size = np.array([2, 2]).astype("int32")
self.align_corners = True
self.align_mode = 1
class TestBilinearInterpCase5(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
self.scale = 0.
self.out_size = np.array([11, 11]).astype("int32")
self.align_corners = True
self.align_mode = 1
class TestBilinearInterpCase6(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [1, 1, 32, 64]
self.out_h = 64
self.out_w = 32
self.scale = 0.
self.out_size = np.array([65, 33]).astype("int32")
self.align_corners = True
self.align_mode = 1
class TestBilinearInterpSame(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 32, 64]
self.out_h = 32
self.out_w = 64
self.scale = 0.
self.align_corners = True
self.align_mode = 1
class TestBilinearInterpActualShape(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [3, 2, 32, 16]
self.out_h = 64
self.out_w = 32
self.scale = 0.
self.out_size = np.array([66, 40]).astype("int32")
self.align_corners = True
self.align_mode = 1
class TestBilinearInterpDataLayout(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 5, 5, 3]
self.out_h = 2
self.out_w = 2
self.scale = 0.
self.out_size = np.array([3, 3]).astype("int32")
self.align_corners = True
self.align_mode = 1
self.data_layout = "NHWC"
class TestBilinearInterpOpUint8(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.init_test_case()
self.op_type = "bilinear_interp_v2"
input_np = np.random.randint(
low=0, high=256, size=self.input_shape).astype("uint8")
if self.scale:
if isinstance(self.scale, float) or isinstance(self.scale, int):
if self.scale > 0:
scale_h = scale_w = float(self.scale)
if isinstance(self.scale, list) and len(self.scale) == 1:
scale_w = scale_h = self.scale[0]
elif isinstance(self.scale, list) and len(self.scale) > 1:
scale_w = self.scale[1]
scale_h = self.scale[0]
out_h = int(self.input_shape[2] * scale_h)
out_w = int(self.input_shape[3] * scale_w)
else:
out_h = self.out_h
out_w = self.out_w
output_np = bilinear_interp_np(input_np, out_h, out_w, self.out_size,
self.actual_shape, self.align_corners,
self.align_mode)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
self.attrs = {
'out_h': self.out_h,
'out_w': self.out_w,
'interp_method': self.interp_method,
'align_corners': self.align_corners,
'align_mode': self.align_mode
}
if self.scale:
if isinstance(self.scale, float) or isinstance(self.scale, int):
if self.scale > 0:
self.scale = [self.scale]
if isinstance(self.scale, list) and len(self.scale) == 1:
self.scale = [self.scale[0], self.scale[0]]
self.attrs['scale'] = self.scale
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output_with_place(place=core.CPUPlace(), atol=1)
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [1, 3, 9, 6]
self.out_h = 10
self.out_w = 9
self.scale = 0.
self.align_corners = True
self.align_mode = 1
class TestBilinearInterpCase1Uint8(TestBilinearInterpOpUint8):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 32, 64]
self.out_h = 64
self.out_w = 32
self.scale = 0.
self.align_corners = True
self.align_mode = 1
class TestBilinearInterpCase2Uint8(TestBilinearInterpOpUint8):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [4, 1, 7, 8]
self.out_h = 5
self.out_w = 13
self.scale = 0.
self.out_size = np.array([6, 15]).astype("int32")
self.align_corners = True
self.align_mode = 1
class TestBilinearInterpOtherMethod1(TestBilinearInterpOp):
def set_align_mode(self):
self.align_corners = False
self.align_mode = 1
class TestBilinearInterpWithMethod2(TestBilinearInterpOp):
def set_align_mode(self):
self.align_corners = False
self.align_mode = 0
class TestBilinearInterpWithMethod3(TestBilinearInterpOp):
def set_align_mode(self):
self.align_corners = True
self.align_mode = 0
class TestBilinearInterpScale1(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 5, 7]
self.out_h = 60
self.out_w = 25
self.scale = 2.
self.align_corners = True
self.align_mode = 1
class TestBilinearInterpScale2(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 5, 7]
self.out_h = 60
self.out_w = 25
self.scale = 1.
self.align_corners = True
self.align_mode = 1
class TestBilinearInterpScale3(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 5, 7]
self.out_h = 60
self.out_w = 25
self.scale = 1.5
self.align_corners = True
self.align_mode = 1
class TestBilinearInterpScale4(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 5, 7]
self.out_h = 60
self.out_w = 25
self.scale = [1.5, 0.5]
self.align_corners = True
self.align_mode = 1
class TestBilinearInterpZero(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 5, 7]
self.out_h = 60
self.out_w = 25
self.scale = 0.2
self.align_corners = False
self.align_mode = 0
class TestBilinearInterpOp_attr_tensor(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.init_test_case()
self.op_type = "bilinear_interp_v2"
self.shape_by_1Dtensor = False
self.scale_by_1Dtensor = False
self.attrs = {
'interp_method': self.interp_method,
'align_corners': self.align_corners,
}
input_np = np.random.random(self.input_shape).astype("float64")
self.inputs = {'X': input_np}
if self.scale_by_1Dtensor:
self.inputs['Scale'] = np.array([self.scale]).astype("float32")
elif self.scale:
if isinstance(self.scale, float) or isinstance(self.scale, int):
if self.scale > 0:
scale_h = scale_w = float(self.scale)
if isinstance(self.scale, list) and len(self.scale) == 1:
scale_w = scale_h = self.scale[0]
elif isinstance(self.scale, list) and len(self.scale) > 1:
scale_w = self.scale[1]
scale_h = self.scale[0]
out_h = int(self.input_shape[2] * scale_h)
out_w = int(self.input_shape[3] * scale_w)
else:
out_h = self.out_h
out_w = self.out_w
if self.shape_by_1Dtensor:
self.inputs['OutSize'] = self.out_size
elif self.out_size is not None:
size_tensor = []
for index, ele in enumerate(self.out_size):
size_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs['SizeTensor'] = size_tensor
self.attrs['out_h'] = self.out_h
self.attrs['out_w'] = self.out_w
if self.scale:
if isinstance(self.scale, float) or isinstance(self.scale, int):
if self.scale > 0:
self.scale = [self.scale]
if isinstance(self.scale, list) and len(self.scale) == 1:
self.scale = [self.scale[0], self.scale[0]]
self.attrs['scale'] = self.scale
output_np = bilinear_interp_np(input_np, out_h, out_w, self.out_size,
self.actual_shape, self.align_corners)
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', in_place=True)
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 5, 5]
self.out_h = 3
self.out_w = 3
self.scale = 0.
self.out_size = [3, 3]
self.align_corners = True
# out_size is a 1-D tensor
class TestBilinearInterp_attr_tensor_Case1(TestBilinearInterpOp_attr_tensor):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
self.scale = 0.
self.out_size = [8, 12]
self.align_corners = True
# scale is a 1-D tensor
class TestBilinearInterp_attr_tensor_Case2(TestBilinearInterpOp_attr_tensor):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [3, 2, 32, 16]
self.out_h = 64
self.out_w = 32
self.scale = 0.
self.out_size = np.array([66, 40]).astype("int32")
self.align_corners = True
self.shape_by_1Dtensor = True
# scale is a 1-D tensor
class TestBilinearInterp_attr_tensor_Case3(TestBilinearInterpOp_attr_tensor):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [3, 2, 32, 16]
self.out_h = 64
self.out_w = 32
self.scale = 2.0
self.out_size = None
self.align_corners = True
self.scale_by_1Dtensor = True
class TestBilinearInterpOpAPI(unittest.TestCase):
def test_case(self):
x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32")
dim = fluid.data(name="dim", shape=[1], dtype="int32")
shape_tensor = fluid.data(name="shape_tensor", shape=[2], dtype="int32")
actual_size = fluid.data(name="actual_size", shape=[2], dtype="int32")
scale_tensor = fluid.data(
name="scale_tensor", shape=[1], dtype="float32")
out1 = fluid.layers.resize_bilinear(x, out_shape=[12, 12])
out2 = fluid.layers.resize_bilinear(x, out_shape=[12, dim])
out3 = fluid.layers.resize_bilinear(x, out_shape=shape_tensor)
out4 = fluid.layers.resize_bilinear(
x, out_shape=[4, 4], actual_shape=actual_size)
out5 = fluid.layers.resize_bilinear(x, scale=scale_tensor)
x_data = np.random.random((2, 3, 6, 6)).astype("float32")
dim_data = np.array([12]).astype("int32")
shape_data = np.array([12, 12]).astype("int32")
actual_size_data = np.array([12, 12]).astype("int32")
scale_data = np.array([2.0]).astype("float32")
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
results = exe.run(fluid.default_main_program(),
feed={
"x": x_data,
"dim": dim_data,
"shape_tensor": shape_data,
"actual_size": actual_size_data,
"scale_tensor": scale_data
},
fetch_list=[out1, out2, out3, out4, out5],
return_numpy=True)
expect_res = bilinear_interp_np(
x_data, out_h=12, out_w=12, align_corners=True)
for res in results:
self.assertTrue(np.allclose(res, expect_res))
class TestUpsampleBilinear2dInterpOpAPI2_0(unittest.TestCase):
def test_case(self):
# dygraph
x_data = np.random.random((1, 3, 6, 6)).astype("float32")
upsample = paddle.nn.UpsamplingBilinear2d(scale_factor=[2, 2])
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(x_data)
interp = upsample(x)
expect = bilinear_interp_np(
x_data, out_h=12, out_w=12, align_corners=True)
self.assertTrue(np.allclose(interp.numpy(), expect))
class TestBilinearInterpOpAPI_dy(unittest.TestCase):
def test_case(self):
import paddle
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
with fluid.dygraph.guard(place):
input_data = np.random.random((2, 3, 6, 6)).astype("float32")
input_x = paddle.to_tensor(input_data)
expect_res = bilinear_interp_np(
input_data, out_h=12, out_w=12, align_corners=False)
out = interpolate(
x=input_x, size=[12, 12], mode="bilinear", align_corners=False)
self.assertTrue(np.allclose(out.numpy(), expect_res))
if __name__ == "__main__":
unittest.main()
......@@ -21,7 +21,7 @@ import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
from paddle.nn.functional import *
from paddle.nn.functional import interpolate
def linear_interp_np(input,
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import platform
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
from paddle.nn.functional import interpolate
def linear_interp_np(input,
out_w,
out_size=None,
actual_shape=None,
align_corners=True,
align_mode=0,
data_layout='NCHW'):
if data_layout == "NHWC":
input = np.transpose(input, (0, 2, 1)) # NHWC => NCHW
if out_size is not None:
out_w = out_size[0]
if actual_shape is not None:
out_w = actual_shape[0]
batch_size, channel, in_w = input.shape
ratio_w = 0.0
if out_w > 1:
if (align_corners):
ratio_w = (in_w - 1.0) / (out_w - 1.0)
else:
ratio_w = 1.0 * in_w / out_w
out = np.zeros((batch_size, channel, out_w))
for j in range(out_w):
if (align_mode == 0 and not align_corners):
w = int(ratio_w * (j + 0.5) - 0.5)
else:
w = int(ratio_w * j)
w = max(0, w)
wid = 1 if w < in_w - 1 else 0
if (align_mode == 0 and not align_corners):
idx_src_w = max(ratio_w * (j + 0.5) - 0.5, 0)
w1lambda = idx_src_w - w
else:
w1lambda = ratio_w * j - w
w2lambda = 1.0 - w1lambda
out[:, :, j] = w2lambda * input[:, :, w] + w1lambda * input[:, :, w +
wid]
if data_layout == "NHWC":
out = np.transpose(out, (0, 2, 1)) # NCHW => NHWC
return out.astype(input.dtype)
class TestLinearInterpOp(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.data_layout = 'NCHW'
self.init_test_case()
self.op_type = "linear_interp_v2"
input_np = np.random.random(self.input_shape).astype("float64")
if self.data_layout == "NCHW":
in_w = self.input_shape[2]
else:
in_w = self.input_shape[1]
if self.scale > 0:
if isinstance(self.scale, float) or isinstance(self.scale, int):
self.scale = float(self.scale)
if isinstance(self.scale, list):
self.scale = float(self.scale[0])
out_w = int(in_w * self.scale)
else:
out_w = self.out_w
output_np = linear_interp_np(input_np, out_w, self.out_size,
self.actual_shape, self.align_corners,
self.align_mode, self.data_layout)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
if self.actual_shape is not None:
self.inputs['OutSize'] = self.actual_shape
self.attrs = {
'out_w': self.out_w,
'interp_method': self.interp_method,
'align_corners': self.align_corners,
'align_mode': self.align_mode,
'data_layout': self.data_layout
}
if self.scale > 0:
if isinstance(self.scale, float) or isinstance(self.scale, int):
self.scale = [float(self.scale)]
self.attrs['scale'] = self.scale
self.outputs = {'Out': output_np}
def test_check_output(self):
if platform.system() == "Linux":
self.check_output(atol=1e-7)
else:
self.check_output(atol=1e-5)
def test_check_grad(self):
self.check_grad(['X'], 'Out', in_place=True)
def init_test_case(self):
self.interp_method = 'linear'
self.input_shape = [1, 3, 100]
self.out_w = 50
self.scale = 0.
self.out_size = np.array([50, ]).astype("int32")
self.align_corners = False
self.align_mode = 1
class TestLinearInterpOpDataLayout(TestLinearInterpOp):
def init_test_case(self):
self.interp_method = 'linear'
self.input_shape = [1, 3, 100]
self.out_w = 50
self.scale = 0.
self.out_size = np.array([50, ]).astype("int32")
self.align_corners = False
self.align_mode = 1
self.data_layout = 'NHWC'
class TestLinearInterpOpAlignMode(TestLinearInterpOp):
def init_test_case(self):
self.interp_method = 'linear'
self.input_shape = [1, 3, 100]
self.out_w = 50
self.scale = 0.
self.out_size = np.array([50, ]).astype("int32")
self.align_corners = False
self.align_mode = 0
class TestLinearInterpOpScale(TestLinearInterpOp):
def init_test_case(self):
self.interp_method = 'linear'
self.input_shape = [1, 3, 100]
self.out_w = 50
self.scale = 0.5
self.out_size = np.array([50, ]).astype("int32")
self.align_corners = False
self.align_mode = 0
class TestLinearInterpOpSizeTensor(TestLinearInterpOp):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.data_layout = 'NCHW'
self.init_test_case()
self.op_type = "linear_interp_v2"
input_np = np.random.random(self.input_shape).astype("float64")
self.shape_by_1Dtensor = False
self.scale_by_1Dtensor = False
if self.data_layout == "NCHW":
in_w = self.input_shape[2]
else:
in_w = self.input_shape[1]
if self.scale > 0:
if isinstance(self.scale, float) or isinstance(self.scale, int):
self.scale = float(self.scale)
if isinstance(self.scale, list):
self.scale = float(self.scale[0])
out_w = int(in_w * self.scale)
else:
out_w = self.out_w
output_np = linear_interp_np(input_np, out_w, self.out_size,
self.actual_shape, self.align_corners,
self.align_mode, self.data_layout)
self.inputs = {'X': input_np}
if self.out_size is not None and self.shape_by_1Dtensor:
self.inputs['OutSize'] = self.out_size
elif self.actual_shape is not None and self.shape_by_1Dtensor:
self.inputs['OutSize'] = self.actual_shape
else:
size_tensor = []
for index, ele in enumerate(self.out_size):
size_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs['SizeTensor'] = size_tensor
self.attrs = {
'out_w': self.out_w,
'interp_method': self.interp_method,
'align_corners': self.align_corners,
'align_mode': self.align_mode,
'data_layout': self.data_layout
}
if self.scale > 0:
if isinstance(self.scale, float) or isinstance(self.scale, int):
self.scale = [self.scale]
if isinstance(self.scale, list) and len(self.scale) == 1:
self.scale = [self.scale[0], self.scale[0]]
self.attrs['scale'] = self.scale
self.outputs = {'Out': output_np}
class TestResizeLinearAPI(unittest.TestCase):
def test_case(self):
x = fluid.data(name="x", shape=[1, 3, 64], dtype="float32")
dim = fluid.data(name="dim", shape=[1], dtype="int32")
shape_tensor = fluid.data(name="shape_tensor", shape=[1], dtype="int32")
actual_size = fluid.data(name="actual_size", shape=[1], dtype="int32")
scale_tensor = fluid.data(
name="scale_tensor", shape=[1], dtype="float32")
out1 = fluid.layers.resize_linear(
x, out_shape=[128, ], align_mode=1, align_corners=False)
out2 = fluid.layers.resize_linear(
x, out_shape=[128], align_mode=1, align_corners=False)
out3 = fluid.layers.resize_linear(
x, out_shape=shape_tensor, align_mode=1, align_corners=False)
out4 = fluid.layers.resize_linear(
x,
out_shape=[128, ],
actual_shape=actual_size,
align_mode=1,
align_corners=False)
out5 = fluid.layers.resize_linear(
x, scale=scale_tensor, align_mode=1, align_corners=False)
out6 = interpolate(
x,
scale_factor=scale_tensor,
mode='linear',
align_mode=1,
align_corners=False,
data_format='NCW')
out7 = interpolate(
x,
size=[128, ],
mode='linear',
align_mode=1,
align_corners=False,
data_format='NCW')
out8 = interpolate(
x,
size=shape_tensor,
mode='linear',
align_mode=1,
align_corners=False,
data_format='NCW')
x_data = np.random.random((1, 3, 64)).astype("float32")
dim_data = np.array([128]).astype("int32")
shape_data = np.array([128, ]).astype("int32")
actual_size_data = np.array([128, ]).astype("int32")
scale_data = np.array([2.0]).astype("float32")
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
results = exe.run(
fluid.default_main_program(),
feed={
"x": x_data,
"dim": dim_data,
"shape_tensor": shape_data,
"actual_size": actual_size_data,
"scale_tensor": scale_data
},
fetch_list=[out1, out2, out3, out4, out5, out6, out7, out8],
return_numpy=True)
expect_res = linear_interp_np(
x_data, out_w=128, align_mode=1, align_corners=False)
for res in results:
self.assertTrue(np.allclose(res, expect_res))
class TestLinearInterpOpAPI2_0(unittest.TestCase):
def test_case(self):
# dygraph
x_data = np.random.random((1, 3, 128)).astype("float32")
us_1 = paddle.nn.UpSample(
size=[64, ],
mode='linear',
align_mode=1,
align_corners=False,
data_format='NCW')
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(x_data)
interp = us_1(x)
expect = linear_interp_np(
x_data, out_w=64, align_mode=1, align_corners=False)
self.assertTrue(np.allclose(interp.numpy(), expect))
class TestResizeLinearOpUint8(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.init_test_case()
self.op_type = "linear_interp_v2"
input_np = np.random.random(self.input_shape).astype("uint8")
if self.scale > 0:
if isinstance(self.scale, float) or isinstance(self.scale, int):
self.scale = float(self.scale)
if isinstance(self.scale, list):
self.scale = float(self.scale[0])
out_w = int(self.input_shape[2] * self.scale)
else:
out_w = self.out_w
output_np = linear_interp_np(input_np, out_w, self.out_size,
self.actual_shape, self.align_corners,
self.align_mode)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
self.attrs = {
'out_w': self.out_w,
'interp_method': self.interp_method,
'align_corners': self.align_corners,
'align_mode': self.align_mode
}
if self.scale > 0:
if isinstance(self.scale, float) or isinstance(self.scale, int):
self.scale = [self.scale]
if isinstance(self.scale, list) and len(self.scale) == 1:
self.scale = [self.scale[0], self.scale[0]]
self.attrs['scale'] = self.scale
self.outputs = {'Out': output_np}
def test_check_output(self):
if platform.system() == "Linux":
self.check_output_with_place(place=core.CPUPlace(), atol=1e-7)
else:
self.check_output_with_place(place=core.CPUPlace(), atol=1e-5)
def init_test_case(self):
self.interp_method = 'linear'
self.input_shape = [2, 3, 100]
self.out_w = 50
self.scale = 0.
self.out_size = np.array([50, ]).astype("int32")
self.align_corners = True
self.align_mode = 1
class TestLinearInterpOpException(unittest.TestCase):
def test_exception(self):
def input_shape_error():
x1 = fluid.data(name="x1", shape=[1], dtype="float32")
out = fluid.layers.resize_linear(
x1, out_shape=[256, ], data_format='NCW')
def data_format_error():
x2 = fluid.data(name="x2", shape=[1, 3, 128], dtype="float32")
out = fluid.layers.resize_linear(
x2, out_shape=[256, ], data_format='NHWCD')
def out_shape_error():
x3 = fluid.data(name="x3", shape=[1, 3, 128], dtype="float32")
out = fluid.layers.resize_linear(
x3, out_shape=[
256,
256,
], data_format='NHWC')
self.assertRaises(ValueError, input_shape_error)
self.assertRaises(ValueError, data_format_error)
self.assertRaises(ValueError, out_shape_error)
class TestLinearInterpOpError(unittest.TestCase):
def test_error(self):
with program_guard(Program(), Program()):
def input_shape_error():
x1 = fluid.data(name="x1", shape=[1], dtype="float32")
out1 = paddle.nn.UpSample(
size=[256, ], data_format='NCW', mode='linear')
out1_res = out1(x1)
def data_format_error():
x2 = fluid.data(name="x2", shape=[1, 3, 128], dtype="float32")
out2 = paddle.nn.UpSample(
size=[256, ], data_format='NHWCD', mode='linear')
out2_res = out2(x2)
def out_shape_error():
x3 = fluid.data(name="x3", shape=[1, 3, 128], dtype="float32")
out3 = paddle.nn.UpSample(
size=[
256,
256,
], data_format='NHWC', mode='linear')
out3_res = out3(x3)
self.assertRaises(ValueError, input_shape_error)
self.assertRaises(ValueError, data_format_error)
self.assertRaises(ValueError, out_shape_error)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
import paddle.fluid as fluid
import paddle.nn as nn
import paddle
def nearest_neighbor_interp_np(X,
out_h,
out_w,
out_size=None,
actual_shape=None,
align_corners=True,
data_layout='NCHW'):
"""nearest neighbor interpolation implement in shape [N, C, H, W]"""
if data_layout == "NHWC":
X = np.transpose(X, (0, 3, 1, 2)) # NHWC => NCHW
if out_size is not None:
out_h = out_size[0]
out_w = out_size[1]
if actual_shape is not None:
out_h = actual_shape[0]
out_w = actual_shape[1]
n, c, in_h, in_w = X.shape
ratio_h = ratio_w = 0.0
if (out_h > 1):
if (align_corners):
ratio_h = (in_h - 1.0) / (out_h - 1.0)
else:
ratio_h = 1.0 * in_h / out_h
if (out_w > 1):
if (align_corners):
ratio_w = (in_w - 1.0) / (out_w - 1.0)
else:
ratio_w = 1.0 * in_w / out_w
out = np.zeros((n, c, out_h, out_w))
if align_corners:
for i in range(out_h):
in_i = int(ratio_h * i + 0.5)
for j in range(out_w):
in_j = int(ratio_w * j + 0.5)
out[:, :, i, j] = X[:, :, in_i, in_j]
else:
for i in range(out_h):
in_i = int(ratio_h * i)
for j in range(out_w):
in_j = int(ratio_w * j)
out[:, :, i, j] = X[:, :, in_i, in_j]
if data_layout == "NHWC":
out = np.transpose(out, (0, 2, 3, 1)) # NCHW => NHWC
return out.astype(X.dtype)
class TestNearestInterpOp(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.data_layout = 'NCHW'
self.init_test_case()
self.op_type = "nearest_interp_v2"
input_np = np.random.random(self.input_shape).astype("float64")
if self.data_layout == "NCHW":
in_h = self.input_shape[2]
in_w = self.input_shape[3]
else:
in_h = self.input_shape[1]
in_w = self.input_shape[2]
if self.scale:
if isinstance(self.scale, float) or isinstance(self.scale, int):
if self.scale > 0:
scale_h = scale_w = float(self.scale)
if isinstance(self.scale, list) and len(self.scale) == 1:
scale_w = scale_h = self.scale[0]
elif isinstance(self.scale, list) and len(self.scale) > 1:
scale_w = self.scale[1]
scale_h = self.scale[0]
out_h = int(in_h * scale_h)
out_w = int(in_w * scale_w)
else:
out_h = self.out_h
out_w = self.out_w
output_np = nearest_neighbor_interp_np(
input_np, out_h, out_w, self.out_size, self.actual_shape,
self.align_corners, self.data_layout)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
if self.actual_shape is not None:
self.inputs['OutSize'] = self.actual_shape
self.attrs = {
'out_h': self.out_h,
'out_w': self.out_w,
'interp_method': self.interp_method,
'align_corners': self.align_corners,
'data_layout': self.data_layout
}
if self.scale:
if isinstance(self.scale, float) or isinstance(self.scale, int):
if self.scale > 0:
self.scale = [self.scale]
if isinstance(self.scale, list) and len(self.scale) == 1:
self.scale = [self.scale[0], self.scale[0]]
self.attrs['scale'] = self.scale
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', in_place=True)
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [2, 3, 4, 5]
self.out_h = 2
self.out_w = 2
self.scale = 0.
self.out_size = np.array([3, 3]).astype("int32")
self.align_corners = True
class TestNearestNeighborInterpCase1(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
self.scale = 0.
self.align_corners = True
class TestNearestNeighborInterpCase2(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
self.scale = 0.
self.align_corners = True
class TestNearestNeighborInterpCase3(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [1, 1, 32, 64]
self.out_h = 64
self.out_w = 32
self.scale = 0.
self.align_corners = True
class TestNearestNeighborInterpCase4(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
self.scale = 0.
self.out_size = np.array([2, 2]).astype("int32")
self.align_corners = True
class TestNearestNeighborInterpCase5(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
self.scale = 0.
self.out_size = np.array([11, 11]).astype("int32")
self.align_corners = True
class TestNearestNeighborInterpCase6(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [1, 1, 32, 64]
self.out_h = 64
self.out_w = 32
self.scale = 0.
self.out_size = np.array([65, 129]).astype("int32")
self.align_corners = True
class TestNearestNeighborInterpSame(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [2, 3, 32, 64]
self.out_h = 32
self.out_w = 64
self.scale = 0.
self.align_corners = True
class TestNearestNeighborInterpActualShape(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 2, 32, 16]
self.out_h = 64
self.out_w = 32
self.scale = 0.
self.out_size = np.array([66, 40]).astype("int32")
self.align_corners = True
class TestNearestNeighborInterpDataLayout(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [2, 4, 4, 5]
self.out_h = 2
self.out_w = 2
self.scale = 0.
self.out_size = np.array([3, 8]).astype("int32")
self.align_corners = True
self.data_layout = "NHWC"
class TestNearestInterpOpUint8(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.init_test_case()
self.op_type = "nearest_interp_v2"
input_np = np.random.randint(
low=0, high=256, size=self.input_shape).astype("uint8")
if self.scale:
if isinstance(self.scale, float) or isinstance(self.scale, int):
if self.scale > 0:
scale_h = scale_w = float(self.scale)
if isinstance(self.scale, list) and len(self.scale) == 1:
scale_w = scale_h = self.scale[0]
elif isinstance(self.scale, list) and len(self.scale) > 1:
scale_w = self.scale[1]
scale_h = self.scale[0]
out_h = int(self.input_shape[2] * scale_h)
out_w = int(self.input_shape[3] * scale_w)
else:
out_h = self.out_h
out_w = self.out_w
output_np = nearest_neighbor_interp_np(input_np, out_h, out_w,
self.out_size, self.actual_shape,
self.align_corners)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
self.attrs = {
'out_h': self.out_h,
'out_w': self.out_w,
'interp_method': self.interp_method,
'align_corners': self.align_corners
}
if self.scale:
if isinstance(self.scale, float) or isinstance(self.scale, int):
if self.scale > 0:
self.scale = [self.scale]
if isinstance(self.scale, list) and len(self.scale) == 1:
self.scale = [self.scale[0], self.scale[0]]
self.attrs['scale'] = self.scale
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output_with_place(place=core.CPUPlace(), atol=1)
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [1, 3, 9, 6]
self.out_h = 10
self.out_w = 9
self.scale = 0.
self.align_corners = True
class TestNearestNeighborInterpCase1Uint8(TestNearestInterpOpUint8):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [2, 3, 32, 64]
self.out_h = 80
self.out_w = 40
self.scale = 0.
self.align_corners = True
class TestNearestNeighborInterpCase2Uint8(TestNearestInterpOpUint8):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [4, 1, 7, 8]
self.out_h = 5
self.out_w = 13
self.scale = 0.
self.out_size = np.array([6, 15]).astype("int32")
self.align_corners = True
class TestNearestInterpWithoutCorners(TestNearestInterpOp):
def set_align_corners(self):
self.align_corners = False
class TestNearestNeighborInterpScale1(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 2, 7, 5]
self.out_h = 64
self.out_w = 32
self.scale = 2.
self.out_size = np.array([66, 40]).astype("int32")
self.align_corners = True
class TestNearestNeighborInterpScale2(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 2, 5, 7]
self.out_h = 64
self.out_w = 32
self.scale = 1.5
self.out_size = np.array([66, 40]).astype("int32")
self.align_corners = True
class TestNearestNeighborInterpScale3(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 2, 7, 5]
self.out_h = 64
self.out_w = 32
self.scale = [2.0, 3.0]
self.out_size = np.array([66, 40]).astype("int32")
self.align_corners = True
class TestNearestInterpOp_attr_tensor(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.init_test_case()
self.op_type = "nearest_interp_v2"
self.shape_by_1Dtensor = False
self.scale_by_1Dtensor = False
self.attrs = {
'interp_method': self.interp_method,
'align_corners': self.align_corners,
}
input_np = np.random.random(self.input_shape).astype("float64")
self.inputs = {'X': input_np}
if self.scale_by_1Dtensor:
self.inputs['Scale'] = np.array([self.scale]).astype("float64")
elif self.scale:
if isinstance(self.scale, float) or isinstance(self.scale, int):
if self.scale > 0:
scale_h = scale_w = float(self.scale)
if isinstance(self.scale, list) and len(self.scale) == 1:
scale_w = scale_h = self.scale[0]
elif isinstance(self.scale, list) and len(self.scale) > 1:
scale_w = self.scale[1]
scale_h = self.scale[0]
out_h = int(self.input_shape[2] * scale_h)
out_w = int(self.input_shape[3] * scale_w)
else:
out_h = self.out_h
out_w = self.out_w
if self.shape_by_1Dtensor:
self.inputs['OutSize'] = self.out_size
elif self.out_size is not None:
size_tensor = []
for index, ele in enumerate(self.out_size):
size_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs['SizeTensor'] = size_tensor
self.attrs['out_h'] = self.out_h
self.attrs['out_w'] = self.out_w
if self.scale:
if isinstance(self.scale, float) or isinstance(self.scale, int):
if self.scale > 0:
self.scale = [self.scale]
if isinstance(self.scale, list) and len(self.scale) == 1:
self.scale = [self.scale[0], self.scale[0]]
self.attrs['scale'] = self.scale
output_np = nearest_neighbor_interp_np(input_np, out_h, out_w,
self.out_size, self.actual_shape,
self.align_corners)
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', in_place=True)
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [2, 5, 4, 4]
self.out_h = 3
self.out_w = 3
self.scale = 0.
self.out_size = [3, 3]
self.align_corners = True
# out_size is a tensor list
class TestNearestInterp_attr_tensor_Case1(TestNearestInterpOp_attr_tensor):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
self.scale = 0.
self.out_size = [8, 12]
self.align_corners = True
# out_size is a 1-D tensor
class TestNearestInterp_attr_tensor_Case2(TestNearestInterpOp_attr_tensor):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 2, 32, 16]
self.out_h = 64
self.out_w = 32
self.scale = 0.
self.out_size = np.array([66, 40]).astype("int32")
self.align_corners = True
self.shape_by_1Dtensor = True
# scale is a 1-D tensor
class TestNearestInterp_attr_tensor_Case3(TestNearestInterpOp_attr_tensor):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 2, 32, 16]
self.out_h = 64
self.out_w = 32
self.scale = 2.0
self.out_size = None
self.align_corners = True
self.scale_by_1Dtensor = True
class TestNearestAPI(unittest.TestCase):
def test_case(self):
x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32")
y = fluid.data(name="y", shape=[2, 6, 6, 3], dtype="float32")
dim = fluid.data(name="dim", shape=[1], dtype="int32")
shape_tensor = fluid.data(name="shape_tensor", shape=[2], dtype="int32")
actual_size = fluid.data(name="actual_size", shape=[2], dtype="int32")
scale_tensor = fluid.data(
name="scale_tensor", shape=[1], dtype="float32")
out1 = fluid.layers.resize_nearest(
y, out_shape=[12, 12], data_format='NHWC')
out2 = fluid.layers.resize_nearest(x, out_shape=[12, dim])
out3 = fluid.layers.resize_nearest(x, out_shape=shape_tensor)
out4 = fluid.layers.resize_nearest(
x, out_shape=[4, 4], actual_shape=actual_size)
out5 = fluid.layers.resize_nearest(x, scale=scale_tensor)
x_data = np.random.random((2, 3, 6, 6)).astype("float32")
dim_data = np.array([12]).astype("int32")
shape_data = np.array([12, 12]).astype("int32")
actual_size_data = np.array([12, 12]).astype("int32")
scale_data = np.array([2.0]).astype("float32")
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
results = exe.run(fluid.default_main_program(),
feed={
"x": x_data,
"y": np.transpose(x_data, (0, 2, 3, 1)),
"dim": dim_data,
"shape_tensor": shape_data,
"actual_size": actual_size_data,
"scale_tensor": scale_data
},
fetch_list=[out1, out2, out3, out4, out5],
return_numpy=True)
expect_res = nearest_neighbor_interp_np(
x_data, out_h=12, out_w=12, align_corners=True)
self.assertTrue(
np.allclose(results[0], np.transpose(expect_res, (0, 2, 3, 1))))
for i in range(len(results) - 1):
self.assertTrue(np.allclose(results[i + 1], expect_res))
class TestUpsampleNearest2dInterpOpAPI2_0(unittest.TestCase):
def test_case(self):
# dygraph
x_data = np.random.random((1, 3, 6, 6)).astype("float32")
upsample = paddle.nn.UpsamplingNearest2d(scale_factor=[2, 2])
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(x_data)
interp = upsample(x)
expect = nearest_neighbor_interp_np(
x_data, out_h=12, out_w=12, align_corners=False)
self.assertTrue(np.allclose(interp.numpy(), expect))
class TestNearestInterpException(unittest.TestCase):
def test_exception(self):
input = fluid.data(name="input", shape=[1, 3, 6, 6], dtype="float32")
def attr_data_format():
# for 4-D input, data_format can only be NCHW or NHWC
out = fluid.layers.resize_nearest(
input, out_shape=[4, 8], data_format='NDHWC')
def attr_scale_type():
out = fluid.layers.resize_nearest(input, scale='scale')
def attr_scale_value():
out = fluid.layers.resize_nearest(input, scale=-0.3)
self.assertRaises(ValueError, attr_data_format)
self.assertRaises(TypeError, attr_scale_type)
self.assertRaises(ValueError, attr_scale_value)
if __name__ == "__main__":
unittest.main()
......@@ -17,7 +17,7 @@ import unittest
from op_test import OpTest
import numpy as np
import paddle.fluid.core as core
from paddle.nn.functional import *
from paddle.nn.functional import avg_pool2d, max_pool2d
import paddle.fluid as fluid
import paddle
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.nn.functional import interpolate
def trilinear_interp_np(input,
out_d,
out_h,
out_w,
out_size=None,
actual_shape=None,
align_corners=True,
align_mode=0,
data_layout='NCDHW'):
"""trilinear interpolation implement in shape [N, C, D, H, W]"""
if data_layout == "NDHWC":
input = np.transpose(input, (0, 4, 1, 2, 3)) # NDHWC => NCDHW
if out_size is not None:
out_d = out_size[0]
out_h = out_size[1]
out_w = out_size[2]
if actual_shape is not None:
out_d = actual_shape[0]
out_h = actual_shape[1]
out_w = actual_shape[2]
batch_size, channel, in_d, in_h, in_w = input.shape
ratio_d = ratio_h = ratio_w = 0.0
if out_d > 1:
if (align_corners):
ratio_d = (in_d - 1.0) / (out_d - 1.0)
else:
ratio_d = 1.0 * in_d / out_d
if out_h > 1:
if (align_corners):
ratio_h = (in_h - 1.0) / (out_h - 1.0)
else:
ratio_h = 1.0 * in_h / out_h
if out_w > 1:
if (align_corners):
ratio_w = (in_w - 1.0) / (out_w - 1.0)
else:
ratio_w = 1.0 * in_w / out_w
out = np.zeros((batch_size, channel, out_d, out_h, out_w))
for i in range(out_d):
if (align_mode == 0 and not align_corners):
d = int(ratio_d * (i + 0.5) - 0.5)
else:
d = int(ratio_d * i)
d = max(0, d)
did = 1 if d < in_d - 1 else 0
if (align_mode == 0 and not align_corners):
idx_src_d = max(ratio_d * (i + 0.5) - 0.5, 0)
d1lambda = idx_src_d - d
else:
d1lambda = ratio_d * i - d
d2lambda = 1.0 - d1lambda
for j in range(out_h):
if (align_mode == 0 and not align_corners):
h = int(ratio_h * (j + 0.5) - 0.5)
else:
h = int(ratio_h * j)
h = max(0, h)
hid = 1 if h < in_h - 1 else 0
if (align_mode == 0 and not align_corners):
idx_src_h = max(ratio_h * (j + 0.5) - 0.5, 0)
h1lambda = idx_src_h - h
else:
h1lambda = ratio_h * j - h
h2lambda = 1.0 - h1lambda
for k in range(out_w):
if (align_mode == 0 and not align_corners):
w = int(ratio_w * (k + 0.5) - 0.5)
else:
w = int(ratio_w * k)
w = max(0, w)
wid = 1 if w < in_w - 1 else 0
if (align_mode == 0 and not align_corners):
idx_src_w = max(ratio_w * (k + 0.5) - 0.5, 0)
w1lambda = idx_src_w - w
else:
w1lambda = ratio_w * k - w
w2lambda = 1.0 - w1lambda
out[:, :, i, j, k] = \
d2lambda * \
(h2lambda * (w2lambda * input[:, :, d, h, w] + \
w1lambda * input[:, :, d, h, w+wid]) + \
h1lambda * (w2lambda * input[:, :, d, h+hid, w] + \
w1lambda * input[:, :, d, h+hid, w+wid])) + \
d1lambda * \
(h2lambda * (w2lambda * input[:, :, d+did, h, w] + \
w1lambda * input[:, :, d+did, h, w+wid]) + \
h1lambda * (w2lambda * input[:, :, d+did, h+hid, w] + \
w1lambda * input[:, :, d+did, h+hid, w+wid]))
if data_layout == "NDHWC":
out = np.transpose(out, (0, 2, 3, 4, 1)) # NCDHW => NDHWC
return out.astype(input.dtype)
class TestTrilinearInterpOp(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.data_layout = 'NCDHW'
self.init_test_case()
self.op_type = "trilinear_interp_v2"
input_np = np.random.random(self.input_shape).astype("float32")
if self.data_layout == "NCDHW":
in_d = self.input_shape[2]
in_h = self.input_shape[3]
in_w = self.input_shape[4]
else:
in_d = self.input_shape[1]
in_h = self.input_shape[2]
in_w = self.input_shape[3]
if self.scale > 0:
if isinstance(self.scale, float) or isinstance(self.scale, int):
scale_d = scale_h = scale_w = float(self.scale)
if isinstance(self.scale, list) and len(self.scale) == 1:
scale_d = scale_w = scale_h = self.scale[0]
elif isinstance(self.scale, list) and len(self.scale) > 1:
scale_w = self.scale[2]
scale_h = self.scale[1]
scale_d = self.scale[0]
out_d = int(in_d * scale_d)
out_h = int(in_h * scale_h)
out_w = int(in_w * scale_w)
else:
out_d = self.out_d
out_h = self.out_h
out_w = self.out_w
output_np = trilinear_interp_np(
input_np, out_d, out_h, out_w, self.out_size, self.actual_shape,
self.align_corners, self.align_mode, self.data_layout)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
if self.actual_shape is not None:
self.inputs['OutSize'] = self.actual_shape
# c++ end treat NCDHW the same way as NCHW
if self.data_layout == 'NCDHW':
data_layout = 'NCHW'
else:
data_layout = 'NHWC'
self.attrs = {
'out_d': self.out_d,
'out_h': self.out_h,
'out_w': self.out_w,
'interp_method': self.interp_method,
'align_corners': self.align_corners,
'align_mode': self.align_mode,
'data_layout': data_layout
}
if self.scale > 0:
if isinstance(self.scale, float) or isinstance(self.scale, int):
self.scale = [self.scale]
if isinstance(self.scale, list) and len(self.scale) == 1:
self.scale = [self.scale[0], self.scale[0], self.scale[0]]
self.attrs['scale'] = self.scale
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', in_place=True)
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [2, 3, 4, 4, 4]
self.out_d = 2
self.out_h = 2
self.out_w = 2
self.scale = 0.
self.out_size = np.array([3, 3, 3]).astype("int32")
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpCase1(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [2, 1, 7, 8, 9]
self.out_d = 1
self.out_h = 1
self.out_w = 1
self.scale = 0.
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpCase2(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [2, 3, 9, 6, 8]
self.out_d = 12
self.out_h = 12
self.out_w = 12
self.scale = 0.
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpCase3(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [3, 2, 16, 8, 4]
self.out_d = 32
self.out_h = 16
self.out_w = 8
self.scale = 0.
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpCase4(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [4, 1, 7, 8, 9]
self.out_d = 1
self.out_h = 1
self.out_w = 1
self.scale = 0.
self.out_size = np.array([2, 2, 2]).astype("int32")
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpCase5(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [3, 3, 9, 6, 8]
self.out_d = 12
self.out_h = 12
self.out_w = 12
self.scale = 0.
self.out_size = np.array([11, 11, 11]).astype("int32")
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpCase6(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [1, 1, 16, 8, 4]
self.out_d = 8
self.out_h = 32
self.out_w = 16
self.scale = 0.
self.out_size = np.array([17, 9, 5]).astype("int32")
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpSame(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [1, 1, 16, 8, 4]
self.out_d = 16
self.out_h = 8
self.out_w = 4
self.scale = 0.
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpSameHW(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [1, 1, 16, 8, 4]
self.out_d = 8
self.out_h = 8
self.out_w = 4
self.scale = 0.
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpActualShape(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [3, 2, 16, 8, 4]
self.out_d = 64
self.out_h = 32
self.out_w = 16
self.scale = 0.
self.out_size = np.array([33, 19, 7]).astype("int32")
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpDatalayout(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [2, 4, 4, 4, 3]
self.out_d = 2
self.out_h = 2
self.out_w = 2
self.scale = 0.
self.out_size = np.array([3, 3, 3]).astype("int32")
self.align_corners = True
self.align_mode = 1
self.data_layout = "NDHWC"
class TestTrilinearInterpOpUint8(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.init_test_case()
self.op_type = "trilinear_interp_v2"
input_np = np.random.randint(
low=0, high=256, size=self.input_shape).astype("uint8")
if self.scale > 0:
if isinstance(self.scale, float) or isinstance(self.scale, int):
scale_d = scale_h = scale_w = float(self.scale)
if isinstance(self.scale, list) and len(self.scale) == 1:
scale_d = scale_w = scale_h = self.scale[0]
elif isinstance(self.scale, list) and len(self.scale) > 1:
scale_w = self.scale[2]
scale_h = self.scale[1]
scale_d = self.scale[0]
out_d = int(self.input_shape[2] * scale_d)
out_h = int(self.input_shape[3] * scale_h)
out_w = int(self.input_shape[4] * scale_w)
else:
out_d = self.out_d
out_h = self.out_h
out_w = self.out_w
output_np = trilinear_interp_np(input_np, out_d, out_h, out_w,
self.out_size, self.actual_shape,
self.align_corners, self.align_mode)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
self.attrs = {
'out_d': self.out_d,
'out_h': self.out_h,
'out_w': self.out_w,
'interp_method': self.interp_method,
'align_corners': self.align_corners,
'align_mode': self.align_mode
}
if self.scale > 0:
if isinstance(self.scale, float) or isinstance(self.scale, int):
self.scale = [self.scale]
if isinstance(self.scale, list) and len(self.scale) == 1:
self.scale = [self.scale[0], self.scale[0], self.scale[0]]
self.attrs['scale'] = self.scale
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output_with_place(place=core.CPUPlace(), atol=1)
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [1, 3, 9, 6, 8]
self.out_d = 13
self.out_h = 10
self.out_w = 9
self.scale = 0.
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpCase1Uint8(TestTrilinearInterpOpUint8):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [2, 3, 16, 8, 4]
self.out_d = 13
self.out_h = 7
self.out_w = 2
self.scale = 0.
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpCase2Uint8(TestTrilinearInterpOpUint8):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [4, 1, 7, 8, 9]
self.out_d = 3
self.out_h = 5
self.out_w = 13
self.scale = 0.
self.out_size = np.array([6, 15, 21]).astype("int32")
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpOtherMethod1(TestTrilinearInterpOp):
def set_align_mode(self):
self.align_corners = False
self.align_mode = 1
class TestTrilinearInterpWithMethod2(TestTrilinearInterpOp):
def set_align_mode(self):
self.align_corners = False
self.align_mode = 0
class TestTrilinearInterpWithMethod3(TestTrilinearInterpOp):
def set_align_mode(self):
self.align_corners = True
self.align_mode = 0
class TestTrilinearInterpScale1(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [2, 3, 5, 7, 9]
self.out_d = 82
self.out_h = 60
self.out_w = 25
self.scale = 2.
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpScale2(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [2, 3, 5, 7, 9]
self.out_d = 60
self.out_h = 40
self.out_w = 25
self.scale = 1.
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpScale3(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [2, 3, 5, 7, 9]
self.out_d = 60
self.out_h = 40
self.out_w = 25
self.scale = 1.5
self.align_corners = True
self.align_mode = 1
class TestTrilinearInterpZero(TestTrilinearInterpOp):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [2, 3, 5, 7, 11]
self.out_d = 60
self.out_h = 40
self.out_w = 25
self.scale = 0.2
self.align_corners = False
self.align_mode = 0
class TestTrilinearInterpOp_attr_tensor(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.init_test_case()
self.op_type = "trilinear_interp_v2"
self.shape_by_1Dtensor = False
self.scale_by_1Dtensor = False
self.attrs = {
'interp_method': self.interp_method,
'align_corners': self.align_corners,
'align_mode': self.align_mode
}
input_np = np.random.random(self.input_shape).astype("float32")
self.inputs = {'X': input_np}
if self.scale_by_1Dtensor:
self.inputs['Scale'] = np.array([self.scale]).astype("float32")
elif self.scale > 0:
if isinstance(self.scale, float) or isinstance(self.scale, int):
scale_d = scale_h = scale_w = float(self.scale)
if isinstance(self.scale, list) and len(self.scale) == 1:
scale_d = scale_w = scale_h = self.scale[0]
elif isinstance(self.scale, list) and len(self.scale) > 1:
scale_w = self.scale[2]
scale_h = self.scale[1]
scale_d = self.scale[0]
out_d = int(self.input_shape[2] * scale_d)
out_h = int(self.input_shape[3] * scale_h)
out_w = int(self.input_shape[4] * scale_w)
else:
out_d = self.out_d
out_h = self.out_h
out_w = self.out_w
if self.shape_by_1Dtensor:
self.inputs['OutSize'] = self.out_size
elif self.out_size is not None:
size_tensor = []
for index, ele in enumerate(self.out_size):
size_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs['SizeTensor'] = size_tensor
self.attrs['out_d'] = self.out_d
self.attrs['out_h'] = self.out_h
self.attrs['out_w'] = self.out_w
if self.scale > 0:
if isinstance(self.scale, float) or isinstance(self.scale, int):
self.scale = [self.scale]
if isinstance(self.scale, list) and len(self.scale) == 1:
self.scale = [self.scale[0], self.scale[0], self.scale[0]]
self.attrs['scale'] = self.scale
output_np = trilinear_interp_np(input_np, out_d, out_h, out_w,
self.out_size, self.actual_shape,
self.align_corners, self.align_mode)
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', in_place=True)
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [2, 3, 4, 4, 4]
self.out_d = 2
self.out_h = 3
self.out_w = 3
self.scale = 0.
self.out_size = [2, 3, 3]
self.align_corners = True
self.align_mode = 1
# out_size is a 1-D tensor
class TestTrilinearInterp_attr_tensor_Case1(TestTrilinearInterpOp_attr_tensor):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [3, 2, 9, 6, 8]
self.out_d = 32
self.out_h = 16
self.out_w = 8
self.scale = 0.3
self.out_size = [12, 4, 4]
self.align_corners = True
self.align_mode = 1
# scale is a 1-D tensor
class TestTrilinearInterp_attr_tensor_Case2(TestTrilinearInterpOp_attr_tensor):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [2, 3, 8, 8, 4]
self.out_d = 16
self.out_h = 12
self.out_w = 4
self.scale = 0.
self.out_size = [16, 4, 10]
self.align_corners = True
self.align_mode = 1
self.shape_by_1Dtensor = True
# scale is a 1-D tensor
class TestTrilinearInterp_attr_tensor_Case3(TestTrilinearInterpOp_attr_tensor):
def init_test_case(self):
self.interp_method = 'trilinear'
self.input_shape = [2, 3, 8, 8, 4]
self.out_d = 16
self.out_h = 16
self.out_w = 8
self.scale = 2.0
self.out_size = None
self.align_corners = True
self.align_mode = 1
self.scale_by_1Dtensor = True
class TestTrilinearInterpAPI(unittest.TestCase):
def test_case(self):
x = fluid.data(name="x", shape=[2, 3, 6, 9, 4], dtype="float32")
y = fluid.data(name="y", shape=[2, 6, 9, 4, 3], dtype="float32")
dim = fluid.data(name="dim", shape=[1], dtype="int32")
shape_tensor = fluid.data(name="shape_tensor", shape=[3], dtype="int32")
actual_size = fluid.data(name="actual_size", shape=[3], dtype="int32")
scale_tensor = fluid.data(
name="scale_tensor", shape=[1], dtype="float32")
out1 = fluid.layers.resize_trilinear(
y, out_shape=[12, 18, 8], data_format='NDHWC')
out2 = fluid.layers.resize_trilinear(x, out_shape=[12, dim, 8])
out3 = fluid.layers.resize_trilinear(x, out_shape=shape_tensor)
out4 = fluid.layers.resize_trilinear(
x, out_shape=[4, 4, 8], actual_shape=actual_size)
out5 = fluid.layers.resize_trilinear(x, scale=scale_tensor)
out6 = interpolate(
x, scale_factor=scale_tensor, mode='trilinear', data_format="NCDHW")
out7 = interpolate(
x, size=[4, 4, 8], mode='trilinear', data_format="NCDHW")
out8 = interpolate(
x, size=shape_tensor, mode='trilinear', data_format="NCDHW")
x_data = np.random.random((2, 3, 6, 9, 4)).astype("float32")
dim_data = np.array([18]).astype("int32")
shape_data = np.array([12, 18, 8]).astype("int32")
actual_size_data = np.array([12, 18, 8]).astype("int32")
scale_data = np.array([2.0]).astype("float32")
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
results = exe.run(fluid.default_main_program(),
feed={
"x": x_data,
"y": np.transpose(x_data, (0, 2, 3, 4, 1)),
"dim": dim_data,
"shape_tensor": shape_data,
"actual_size": actual_size_data,
"scale_tensor": scale_data
},
fetch_list=[out1, out2, out3, out4, out5],
return_numpy=True)
expect_res = trilinear_interp_np(
x_data, out_d=12, out_h=18, out_w=8, align_mode=1)
self.assertTrue(
np.allclose(results[0], np.transpose(expect_res, (0, 2, 3, 4, 1))))
for i in range(len(results) - 1):
self.assertTrue(np.allclose(results[i + 1], expect_res))
class TestTrilinearInterpOpException(unittest.TestCase):
def test_exception(self):
input = fluid.data(name="input", shape=[2, 3, 6, 9, 4], dtype="float32")
def attr_data_format():
# for 5-D input, data_format only can be NCDHW or NDHWC
out = fluid.layers.resize_trilinear(
input, out_shape=[4, 8, 4], data_format='NHWC')
self.assertRaises(ValueError, attr_data_format)
if __name__ == "__main__":
unittest.main()
......@@ -73,6 +73,7 @@ NO_FP64_CHECK_GRAD_OP_LIST = [
'mish', \
'transpose2', \
'trilinear_interp', \
'trilinear_interp_v2', \
'var_conv_2d', \
'warpctc', \
'bilateral_slice'
......
......@@ -15,6 +15,7 @@
NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [
'affine_channel', \
'bilinear_interp', \
'bilinear_interp_v2',\
'bilinear_tensor_product', \
'conv2d', \
'conv3d', \
......@@ -45,4 +46,6 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [
'cudnn_lstm'
]
NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = ['bilinear_interp']
NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = ['bilinear_interp',\
'bilinear_interp_v2'
]
......@@ -88,6 +88,8 @@ from .layer.common import Embedding #DEFINE_ALIAS
from .layer.common import Linear #DEFINE_ALIAS
from .layer.common import Flatten #DEFINE_ALIAS
from .layer.common import UpSample #DEFINE_ALIAS
from .layer.common import UpsamplingNearest2d #DEFINE_ALIAS
from .layer.common import UpsamplingBilinear2d #DEFINE_ALIAS
from .layer.common import Bilinear #DEFINE_ALIAS
from .layer.common import Dropout #DEFINE_ALIAS
from .layer.common import Dropout2D #DEFINE_ALIAS
......
......@@ -54,30 +54,28 @@ __all__ = [
# 'bilinear_tensor_product',
'assign',
'interpolate',
'upsample',
'bilinear',
'cosine_similarity',
]
def interpolate(input,
def interpolate(x,
size=None,
scale_factor=None,
mode='nearest',
align_corners=False,
align_mode=1,
align_mode=0,
data_format='NCHW',
name=None):
"""
:alias_main: paddle.nn.functional.interpolate
:alias: paddle.nn.functional.interpolate,paddle.nn.functional.common.interpolate
This op resizes a batch of images.
The input must be a 3-D Tensor of the shape (num_batches, channels, in_w)
or 4-D (num_batches, channels, in_h, in_w), or a 5-D Tensor of the shape
(num_batches, channels, in_d, in_h, in_w) or (num_batches, in_d, in_h, in_w, channels),
and the resizing only applies on the three dimensions(depth, height and width).
**Warning:** the parameter :attr:`actual_shape` will be deprecated in the
future and only use :attr:`out_shape` instead.
Supporting resample methods:
'linear' : Linear interpolation
'bilinear' : Bilinear interpolation
......@@ -102,7 +100,7 @@ def interpolate(input,
interpolating functions of three variables (e.g. D-direction,
H-direction and W-direction in this op) on a rectilinear 3D grid.
The linear interpolation is performed on three directions.
Align_corners and align_mode are optional parameters,the calculation method
align_corners and align_mode are optional parameters,the calculation method
of interpolation can be selected by them.
Bicubic interpolation is an extension of cubic interpolation for interpolating
......@@ -132,18 +130,12 @@ def interpolate(input,
W_out = W_{in} * scale_{factor}
Nearest neighbor interpolation:
if:
align_corners = False
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = floor (H_{in} * scale_{factor})
W_out = floor (W_{in} * scale_{factor})
else:
align_corners = True
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = round(H_{in} * scale_{factor})
W_out = round(W_{in} * scale_{factor})
Bilinear interpolation:
if:
......@@ -202,22 +194,22 @@ def interpolate(input,
https://en.wikipedia.org/wiki/Bicubic_interpolation
Parameters:
input (Variable): 3-D, 4-D or 5-D Tensor, its data type is float32, float64, or uint8,
x (Tensor): 3-D, 4-D or 5-D Tensor, its data type is float32, float64, or uint8,
its data format is specified by :attr:`data_format`.
size (list|tuple|Variable|None): Output shape of image resize
size (list|tuple|Tensor|None): Output shape of image resize
layer, the shape is (out_w, ) when input is a 3-D Tensor, the shape is (out_h, out_w)
when input is a 4-D Tensor and is (out_d, out_h, out_w) when input is a 5-D Tensor.
Default: None. If a list, each element can be an integer or a Tensor Variable of shape: [1].
If a Tensor Variable, its dimensions size should be a 1.
scale_factor (float|Variable|None): The multiplier for the input height or width. At
scale_factor (float|Tensor|list|None): The multiplier for the input height or width. At
least one of :attr:`out_shape` or :attr:`scale_factor` must be set.
And :attr:`out_shape` has a higher priority than :attr:`scale_factor`.
And :attr:`out_shape` has a higher priority than :attr:`scale_factor`.Has to match input size if it is a list.
Default: None.
mode (str): The resample method. It supports 'linear', 'nearest', 'bilinear',
'bicubic' and 'trilinear' currently. Default: 'nearest'
align_corners(bool) : An optional bool, If True, the centers of the 4 corner pixels of the
input and output tensors are aligned, preserving the values at the
corner pixels.
corner pixels.This only has an effect when 'linear', 'bilinear', 'bicubic' or 'trilinear'.
Default: False
align_mode(int) : An optional for linear/bilinear/trilinear interpolation. Refer to the formula in the example above,
it can be \'0\' for src_idx = scale_factor*(dst_indx+0.5)-0.5 , can be \'1\' for
......@@ -235,7 +227,7 @@ def interpolate(input,
A 4-D Tensor of the shape (num_batches, channels, out_h, out_w) or (num_batches, out_h, out_w, channels),
or 5-D Tensor of the shape (num_batches, channels, out_d, out_h, out_w) or (num_batches, out_d, out_h, out_w, channels).
Raises:
TypeError: size should be a list or tuple or Variable.
TypeError: size should be a list or tuple or Tensor.
ValueError: The 'mode' of image_resize can only be 'linear', 'bilinear',
'trilinear', 'bicubic', or 'nearest' currently.
ValueError: 'linear' only support 3-D tensor.
......@@ -253,53 +245,27 @@ def interpolate(input,
Examples:
.. code-block:: python
#declarative mode
import paddle
import numpy as np
input = fluid.data(name="input", shape=[None,3,6,10])
#1
output = paddle.nn.functional.interpolate(input=input, size=[12,12])
#2
#x = np.array([2]).astype("int32")
#dim1 = fluid.data(name="dim1", shape=[1], dtype="int32")
#fluid.layers.assign(input=x, output=dim1)
#output = paddle.nn.functional.interpolate(input=input, size=[12,dim1])
#3
#x = np.array([3,12]).astype("int32")
#shape_tensor = fluid.data(name="shape_tensor", shape=[2], dtype="int32")
#fluid.layers.assign(input=x, output=shape_tensor)
#output = paddle.nn.functional.interpolate(input=input, size=shape_tensor)
#4
#x = np.array([0.5]).astype("float32")
#scale_tensor = fluid.data(name="scale", shape=[1], dtype="float32")
#fluid.layers.assign(x,scale_tensor)
#output = paddle.nn.functional.interpolate(input=input, scale_factor=scale_tensor)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
import paddle.nn.functional as F
paddle.disable_static()
# given out size
input_data = np.random.rand(2,3,6,10).astype("float32")
output_data = exe.run(fluid.default_main_program(),
feed={"input":input_data},
fetch_list=[output],
return_numpy=True)
print(output_data[0].shape)
#1
# (2, 3, 12, 12)
#2
# (2, 3, 12, 2)
#3
# (2, 3, 3, 12)
#4
# (2, 3, 3, 5)
#imperative mode
import paddle.fluid.dygraph as dg
with dg.guard(place) as g:
input = dg.to_variable(input_data)
output = paddle.nn.functional.interpolate(input=input, size=[12,12])
print(output.shape)
x = paddle.to_tensor(input_data)
output_1 = F.interpolate(x=x, size=[12,12])
print(output_1.shape)
# [2L, 3L, 12L, 12L]
# given scale
output_2 = F.interpolate(x=x, scale_factor=[2,1])
print(output_2.shape)
# [2L, 3L, 12L, 10L]
# bilinear interp
output_3 = F.interpolate(x=x, scale_factor=[2,1], mode="bilinear")
print(output_2.shape)
# [2L, 3L, 12L, 10L]
"""
data_format = data_format.upper()
resample = mode.upper()
......@@ -317,13 +283,13 @@ def interpolate(input,
"The 'resample' of image_resize can only be 'linaer', 'bilinear', 'trilinear', "
" 'bicubic' or 'nearest' currently.")
if resample in ['LINEAR'] and len(input.shape) != 3:
if resample in ['LINEAR'] and len(x.shape) != 3:
raise ValueError("'linear' only support 3-D tensor.")
if resample in ['BILINEAR', 'NEAREST', 'BICUBIC'] and len(input.shape) != 4:
if resample in ['BILINEAR', 'NEAREST', 'BICUBIC'] and len(x.shape) != 4:
raise ValueError(
"'bilinear', 'bicubic' and 'nearest' only support 4-D tensor.")
if resample == 'TRILINEAR' and len(input.shape) != 5:
if resample == 'TRILINEAR' and len(x.shape) != 5:
raise ValueError("'trilinear'only support 5-D tensor.")
if size is None and scale_factor is None:
......@@ -334,19 +300,21 @@ def interpolate(input,
if align_mode != 0 and align_mode != 1:
raise ValueError("align_mode can only be 0 or 1")
helper = LayerHelper('{}_interp'.format(resample_type), **locals())
if align_corners != 0 and resample == 'NEAREST':
raise ValueError(
"align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear"
)
helper = LayerHelper('{}_interp_v2'.format(resample_type), **locals())
dtype = helper.input_dtype()
if len(input.shape) == 3 and data_format not in ['NCW', 'NWC']:
if len(x.shape) == 3 and data_format not in ['NCW', 'NWC']:
raise ValueError(
"Got wrong value for param `data_format`: " + data_format +
" received but only `NCW` or `NWC` supported for 3-D input.")
elif len(input.shape) == 4 and data_format not in ['NCHW', 'NHWC']:
elif len(x.shape) == 4 and data_format not in ['NCHW', 'NHWC']:
raise ValueError(
"Got wrong value for param `data_format`: " + data_format +
" received but only `NCHW` or `NHWC` supported for 4-D input.")
elif len(input.shape) == 5 and data_format not in ['NCDHW', 'NDHWC']:
elif len(x.shape) == 5 and data_format not in ['NCDHW', 'NDHWC']:
raise ValueError(
"Got wrong value for param `data_format`: " + data_format +
" received but only `NCDHW` or `NDHWC` supported for 5-D input.")
......@@ -359,7 +327,10 @@ def interpolate(input,
if data_format == 'NHWC' or data_format == 'NDHWC' or data_format == 'NWC':
data_layout = 'NHWC'
inputs = {"X": input}
if resample == 'NEAREST':
align_corners = False
inputs = {"X": x}
attrs = {
"out_d": -1,
"out_h": -1,
......@@ -408,7 +379,7 @@ def interpolate(input,
size_list.append(dim)
inputs['SizeTensor'] = new_size_tensor
if len(input.shape) == 3:
if len(x.shape) == 3:
if len(out_shape) != 1:
raise ValueError(
"out_shape length should be 2 for input 3-D tensor")
......@@ -417,7 +388,7 @@ def interpolate(input,
else:
out_shape = list(map(int, out_shape))
attrs['out_w'] = out_shape[0]
if len(input.shape) == 4:
if len(x.shape) == 4:
if len(out_shape) != 2:
raise ValueError("out_shape length should be 2 for "
"input 4-D tensor.")
......@@ -428,7 +399,7 @@ def interpolate(input,
out_shape = list(map(int, out_shape))
attrs['out_h'] = out_shape[0]
attrs['out_w'] = out_shape[1]
if len(input.shape) == 5:
if len(x.shape) == 5:
if len(out_shape) != 3:
raise ValueError("out_shape length should be 3 for "
"input 5-D tensor.")
......@@ -449,20 +420,242 @@ def interpolate(input,
elif isinstance(scale, float) or isinstance(scale, int):
if scale <= 0:
raise ValueError("Attr(scale) should be greater than zero.")
attrs['scale'] = float(scale)
scale_list = []
for i in range(len(x.shape) - 2):
scale_list.append(scale)
attrs['scale'] = list(map(float, scale_list))
elif isinstance(scale, list):
if len(scale) != len(x.shape) - 2:
raise ValueError("scale_shape length should be {} for "
"input {}-D tensor.".format(
len(x.shape) - 2, len(x.shape)))
for value in scale:
if value <= 0:
raise ValueError("Attr(scale) should be greater than zero.")
attrs['scale'] = list(map(float, scale))
else:
raise TypeError(
"Attr(scale)'s type should be float, int or Variable.")
"Attr(scale)'s type should be float, int, list or Tensor.")
if in_dygraph_mode():
attr_list = []
for k, v in attrs.items():
attr_list.append(k)
attr_list.append(v)
dy_attr = tuple(attr_list)
if resample_type == "linear":
out = core.ops.linear_interp_v2(x, *dy_attr)
if resample_type == "bilinear":
out = core.ops.bilinear_interp_v2(x, *dy_attr)
if resample_type == "trilinear":
out = core.ops.trilinear_interp_v2(x, *dy_attr)
if resample_type == "nearest":
out = core.ops.nearest_interp_v2(x, *dy_attr)
if resample_type == "bicubic":
out = core.ops.bicubic_interp_v2(x, *dy_attr)
return out
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='{}_interp'.format(resample_type),
type='{}_interp_v2'.format(resample_type),
inputs=inputs,
outputs={"Out": out},
attrs=attrs)
return out
def upsample(x,
size=None,
scale_factor=None,
mode='nearest',
align_corners=False,
align_mode=0,
data_format='NCHW',
name=None):
"""
This op resizes a batch of images.
The input must be a 3-D Tensor of the shape (num_batches, channels, in_w)
or 4-D (num_batches, channels, in_h, in_w), or a 5-D Tensor of the shape
(num_batches, channels, in_d, in_h, in_w) or (num_batches, in_d, in_h, in_w, channels),
and the resizing only applies on the three dimensions(depth, height and width).
Supporting resample methods:
'linear' : Linear interpolation
'bilinear' : Bilinear interpolation
'trilinear' : Trilinear interpolation
'nearest' : Nearest neighbor interpolation
'bicubic' : Bicubic interpolation
Linear interpolation is the method of using a line connecting two known quantities
to determine the value of an unknown quantity between the two known quantities.
Nearest neighbor interpolation is to perform nearest neighbor interpolation
in both the 3rd dimension(in height direction) and the 4th dimension(in width
direction) on input tensor.
Bilinear interpolation is an extension of linear interpolation for
interpolating functions of two variables (e.g. H-direction and
W-direction in this op) on a rectilinear 2D grid. The key idea is
to perform linear interpolation first in one direction, and then
again in the other direction.
Bicubic interpolation is an extension of cubic interpolation for interpolating
data points on a two-dimensional regular grid. The interpolated surface is
smoother than corresponding surfaces obtained by bilinear interpolation or
nearest-neighbor interpolation.
Trilinear interpolation is an extension of linear interpolation for
interpolating functions of three variables (e.g. D-direction,
H-direction and W-direction in this op) on a rectilinear 3D grid.
The linear interpolation is performed on three directions.
align_corners and align_mode are optional parameters,the calculation method
of interpolation can be selected by them.
Example:
.. code-block:: text
For scale_factor:
if align_corners = True && out_size > 1 :
scale_factor = (in_size-1.0)/(out_size-1.0)
else:
scale_factor = float(in_size/out_size)
Linear interpolation:
if:
align_corners = False , align_mode = 0
input : (N,C,W_in)
output: (N,C,W_out) where:
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,W_in)
output: (N,C,W_out) where:
W_out = W_{in} * scale_{factor}
Nearest neighbor interpolation:
if:
align_corners = False
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = floor (H_{in} * scale_{factor})
W_out = floor (W_{in} * scale_{factor})
else:
align_corners = True
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = round(H_{in} * scale_{factor})
W_out = round(W_{in} * scale_{factor})
Bilinear interpolation:
if:
align_corners = False , align_mode = 0
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = (H_{in}+0.5) * scale_{factor} - 0.5
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
Bicubic interpolation:
if:
align_corners = False
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = (H_{in}+0.5) * scale_{factor} - 0.5
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
Trilinear interpolation:
if:
align_corners = False , align_mode = 0
input : (N,C,D_in,H_in,W_in)
output: (N,C,D_out,H_out,W_out) where:
D_out = (D_{in}+0.5) * scale_{factor} - 0.5
H_out = (H_{in}+0.5) * scale_{factor} - 0.5
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,D_in,H_in,W_in)
output: (N,C,D_out,H_out,W_out) where:
D_out = D_{in} * scale_{factor}
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
https://en.wikipedia.org/wiki/Linear_interpolation.
For details of linear interpolation, please refer to Wikipedia:
For details of nearest neighbor interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation.
For details of bilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bilinear_interpolation.
For details of bicubic interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bicubic_interpolation
For details of trilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Trilinear_interpolation.
Parameters:
x (Tensor): 3-D, 4-D or 5-D Tensor, its data type is float32, float64, or uint8,
its data format is specified by :attr:`data_format`.
size (list|tuple|Tensor|None): Output shape of image resize
layer, the shape is (out_w, ) when input is a 3-D Tensor, the shape is (out_h, out_w)
when input is a 4-D Tensor and is (out_d, out_h, out_w) when input is a 5-D Tensor.
Default: None. If a list, each element can be an integer or a Tensor Variable of shape: [1].
If a Tensor Variable, its dimensions size should be a 1.
scale_factor (float|Tensor|list|None): The multiplier for the input height or width. At
least one of :attr:`out_shape` or :attr:`scale_factor` must be set.
And :attr:`out_shape` has a higher priority than :attr:`scale_factor`.
Default: None.
mode (str): The resample method. It supports 'linear', 'nearest', 'bilinear',
'bicubic' and 'trilinear' currently. Default: 'nearest'
align_corners(bool) : An optional bool, If True, the centers of the 4 corner pixels of the
input and output tensors are aligned, preserving the values at the
corner pixels.
Default: False
align_mode(int) : An optional for linear/bilinear/trilinear interpolation. Refer to the formula in the example above,
it can be \'0\' for src_idx = scale_factor*(dst_indx+0.5)-0.5 , can be \'1\' for
src_idx = scale_factor*dst_index.
data_format (str, optional): Specify the data format of the input, and the data format of the output
will be consistent with that of the input. An optional string from:`NCW`, `NWC`, `"NCHW"`, `"NHWC"`, `"NCDHW"`,
`"NDHWC"`. The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
`[batch_size, input_channels, input_height, input_width]`. When it is `"NCHW"`, the data is stored
in the order of: `[batch_size, input_channels, input_depth, input_height, input_width]`.
name(str, optional): The default value is None.
Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name`
Returns:
A 3-D Tensor of the shape (num_batches, channels, out_w) or (num_batches, out_w, channels),
A 4-D Tensor of the shape (num_batches, channels, out_h, out_w) or (num_batches, out_h, out_w, channels),
or 5-D Tensor of the shape (num_batches, channels, out_d, out_h, out_w) or (num_batches, out_d, out_h, out_w, channels).
Raises:
TypeError: size should be a list or tuple or Tensor.
ValueError: The 'mode' of image_resize can only be 'linear', 'bilinear',
'trilinear', 'bicubic', or 'nearest' currently.
ValueError: 'linear' only support 3-D tensor.
ValueError: 'bilinear', 'bicubic' and 'nearest' only support 4-D tensor.
ValueError: 'trilinear' only support 5-D tensor.
ValueError: One of size and scale_factor must not be None.
ValueError: size length should be 1 for input 3-D tensor.
ValueError: size length should be 2 for input 4-D tensor.
ValueError: size length should be 3 for input 5-D tensor.
ValueError: scale_factor should be greater than zero.
TypeError: align_corners should be a bool value
ValueError: align_mode can only be '0' or '1'
ValueError: data_format can only be 'NCW', 'NWC', 'NCHW', 'NHWC', 'NCDHW' or 'NDHWC'.
Examples:
.. code-block:: python
import paddle
import numpy as np
import paddle.nn.functional as F
paddle.disable_static()
input = paddle.to_tensor(input_data)
output = F.upsample(input=input, size=[12,12])
print(output.shape)
# [2L, 3L, 12L, 12L]
"""
return interpolate(x, size, scale_factor, mode, align_corners, align_mode,
data_format)
def bilinear(x1, x2, weight, bias=None, name=None):
"""
......
......@@ -58,6 +58,8 @@ from .common import Embedding #DEFINE_ALIAS
from .common import Linear #DEFINE_ALIAS
from .common import Flatten #DEFINE_ALIAS
from .common import UpSample #DEFINE_ALIAS
from .common import UpsamplingNearest2d #DEFINE_ALIAS
from .common import UpsamplingBilinear2d #DEFINE_ALIAS
from .common import Dropout #DEFINE_ALIAS
from .common import Dropout2D #DEFINE_ALIAS
from .common import Dropout3D #DEFINE_ALIAS
......
......@@ -29,6 +29,8 @@ __all__ = [
'Linear',
'UpSample',
'Pad2D',
'UpsamplingNearest2d',
'UpsamplingBilinear2d',
'ReflectionPad1d',
'ReplicationPad1d',
'ConstantPad1d',
......@@ -54,8 +56,7 @@ class UpSample(layers.Layer):
or 4-D (num_batches, channels, in_h, in_w), or a 5-D Tensor of the shape
(num_batches, channels, in_d, in_h, in_w) or (num_batches, in_d, in_h, in_w, channels),
and the resizing only applies on the three dimensions(depth, height and width).
**Warning:** the parameter :attr:`actual_shape` will be deprecated in the
future and only use :attr:`out_shape` instead.
Supporting resample methods:
'linear' : Linear interpolation
'bilinear' : Bilinear interpolation
......@@ -85,7 +86,7 @@ class UpSample(layers.Layer):
interpolating functions of three variables (e.g. D-direction,
H-direction and W-direction in this op) on a rectilinear 3D grid.
The linear interpolation is performed on three directions.
Align_corners and align_mode are optional parameters,the calculation method
align_corners and align_mode are optional parameters,the calculation method
of interpolation can be selected by them.
Example:
......@@ -183,16 +184,16 @@ class UpSample(layers.Layer):
https://en.wikipedia.org/wiki/Trilinear_interpolation.
Parameters:
input (Variable): 3-D, 4-D or 5-D Tensor, its data type is float32, float64, or uint8,
x (Tensor): 3-D, 4-D or 5-D Tensor, its data type is float32, float64, or uint8,
its data format is specified by :attr:`data_format`.
size (list|tuple|Variable|None): Output shape of image resize
size (list|tuple|Tensor|None): Output shape of image resize
layer, the shape is (out_w, ) when input is a 3-D Tensor, the shape is (out_h, out_w)
when input is a 4-D Tensor and is (out_d, out_h, out_w) when input is a 5-D Tensor.
Default: None. If a list, each element can be an integer or a Tensor Variable of shape: [1].
If a Tensor Variable, its dimensions size should be a 1.
scale_factor (float|Variable|None): The multiplier for the input height or width. At
scale_factor (float|Tensor|list|None): The multiplier for the input height or width. At
least one of :attr:`out_shape` or :attr:`scale_factor` must be set.
And :attr:`out_shape` has a higher priority than :attr:`scale_factor`.
And :attr:`out_shape` has a higher priority than :attr:`scale_factor`.Has to match input size if it is a list.
Default: None.
mode (str): The resample method. It supports 'linear', 'nearst', 'bilinear',
'bicubic' and 'trilinear' currently. Default: 'nearest'
......@@ -216,7 +217,7 @@ class UpSample(layers.Layer):
A 4-D Tensor of the shape (num_batches, channels, out_h, out_w) or (num_batches, out_h, out_w, channels),
or 5-D Tensor of the shape (num_batches, channels, out_d, out_h, out_w) or (num_batches, out_d, out_h, out_w, channels).
Raises:
TypeError: size should be a list or tuple or Variable.
TypeError: size should be a list or tuple or Tensor.
ValueError: The 'mode' of image_resize can only be 'linear', 'bilinear',
'trilinear', 'bicubic', or 'nearest' currently.
ValueError: 'linear' only support 3-D tensor.
......@@ -234,16 +235,18 @@ class UpSample(layers.Layer):
Examples:
.. code-block:: python
import paddle
import paddle.nn as nn
import numpy as np
import paddle.fluid.dygraph as dg
upsample_op = paddle.nn.UpSample(size=[12,12])
paddle.disable_static()
input_data = np.random.rand(2,3,6,10).astype("float32")
place = paddle.fluid.CPUPlace()
with dg.guard(place) as g:
input = dg.to_variable(input_data)
output = upsample_op(input=input)
upsample_out = paddle.nn.UpSample(size=[12,12])
input = paddle.to_tensor(input_data)
output = upsample_out(x=input)
print(output.shape)
# [2L, 3L, 12L, 12L]
"""
def __init__(self,
......@@ -251,8 +254,9 @@ class UpSample(layers.Layer):
scale_factor=None,
mode='nearest',
align_corners=False,
align_mode=1,
data_format='NCHW'):
align_mode=0,
data_format='NCHW',
name=None):
super(UpSample, self).__init__()
self.size = size
self.scale_factor = scale_factor
......@@ -260,16 +264,184 @@ class UpSample(layers.Layer):
self.align_corners = align_corners
self.align_mode = align_mode
self.data_format = data_format
self.name = name
def forward(self, input):
def forward(self, x):
out = F.interpolate(
input,
x,
size=self.size,
scale_factor=self.scale_factor,
mode=self.mode,
align_corners=self.align_corners,
align_mode=self.align_mode,
data_format=self.data_format)
data_format=self.data_format,
name=self.name)
return out
class UpsamplingNearest2d(layers.Layer):
"""
This op upsamples a batch of images, using nearest neighbours' pixel values.
The input must be a 4-D Tensor of the shape (num_batches, channels, in_h, in_w),
and the upsampling only applies on the two dimensions(height and width).
Nearest neighbor interpolation is to perform nearest neighbor interpolation
in both the 3rd dimension(in height direction) and the 4th dimension(in width
direction) on input tensor.
For details of nearest neighbor interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation.
x (Tensor): 4-D Tensor, its data type is float32, float64, or uint8,
its data format is specified by :attr:`data_format`.
size (list|tuple|Tensor|None): Output shape of image resize
layer, the shape is (out_h, out_w) when input is a 4-D Tensor.
Default: None. If a list, each element can be an integer or a Tensor Variable of shape: [1].
If a Tensor Variable, its dimensions size should be a 1.
scale_factor (float|int|list|Tensor|None): The multiplier for the input height or width. At
least one of :attr:`out_shape` or :attr:`scale_factor` must be set.
And :attr:`out_shape` has a higher priority than :attr:`scale_factor`.
Default: None. Has to match input size if it is a list.
data_format (str, optional): Specify the data format of the input, and the data format of the output
will be consistent with that of the input. An optional string from:`NCW`, `NWC`, `"NCHW"`, `"NHWC"`, `"NCDHW"`,
`"NDHWC"`. The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
`[batch_size, input_channels, input_height, input_width]`. When it is `"NCHW"`, the data is stored
in the order of: `[batch_size, input_channels, input_depth, input_height, input_width]`.
name(str, optional): The default value is None.
Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name`
Returns:
A 4-D Tensor of the shape (num_batches, channels, out_h, out_w) or (num_batches, out_h, out_w, channels),
Raises:
TypeError: size should be a list or tuple or Tensor.
ValueError: 'nearest' only support 4-D tensor.
ValueError: One of size and scale_factor must not be None.
ValueError: size length should be 2 for input 4-D tensor.
ValueError: scale_factor should be greater than zero.
ValueError: data_format can only be 'NCHW', 'NHWC'.
Examples:
.. code-block:: python
import paddle
import paddle.nn as nn
import numpy as np
paddle.disable_static()
input_data = np.random.rand(2,3,6,10).astype("float32")
upsample_out = paddle.nn.UpsamplingNearest2d(size=[12,12])
input = paddle.to_tensor(input_data)
output = upsample_out(x=input)
print(output.shape)
# [2L, 3L, 12L, 12L]
"""
def __init__(self,
size=None,
scale_factor=None,
data_format='NCHW',
name=None):
super(UpsamplingNearest2d, self).__init__()
self.size = size
self.scale_factor = scale_factor
self.data_format = data_format
self.name = name
def forward(self, x):
out = F.interpolate(
x,
size=self.size,
scale_factor=self.scale_factor,
mode='nearest',
align_corners=False,
align_mode=0,
data_format=self.data_format,
name=self.name)
return out
class UpsamplingBilinear2d(layers.Layer):
"""
This op upsamples a batch of images, using bilinear' pixel values.
The input must be a 4-D Tensor of the shape (num_batches, channels, in_h, in_w),
and the upsampling only applies on the two dimensions(height and width).
Bilinear interpolation is an extension of linear interpolation for
interpolating functions of two variables (e.g. H-direction and
W-direction in this op) on a rectilinear 2D grid. The key idea is
to perform linear interpolation first in one direction, and then
again in the other direction.
For details of bilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bilinear_interpolation.
x (Tensor): 4-D Tensor, its data type is float32, float64, or uint8,
its data format is specified by :attr:`data_format`.
size (list|tuple|Tensor|None): Output shape of image resize
layer, the shape is (out_h, out_w) when input is a 4-D Tensor.
Default: None. If a list, each element can be an integer or a Tensor Variable of shape: [1].
If a Tensor Variable, its dimensions size should be a 1.
scale_factor (float|int|list|Tensor|None): The multiplier for the input height or width. At
least one of :attr:`out_shape` or :attr:`scale_factor` must be set.
And :attr:`out_shape` has a higher priority than :attr:`scale_factor`.
Default: None. Has to match input size if it is a list.
data_format (str, optional): Specify the data format of the input, and the data format of the output
will be consistent with that of the input. An optional string from:`NCW`, `NWC`, `"NCHW"`, `"NHWC"`, `"NCDHW"`,
`"NDHWC"`. The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
`[batch_size, input_channels, input_height, input_width]`. When it is `"NCHW"`, the data is stored
in the order of: `[batch_size, input_channels, input_depth, input_height, input_width]`.
name(str, optional): The default value is None.
Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name`
Returns:
A 4-D Tensor of the shape (num_batches, channels, out_h, out_w) or (num_batches, out_h, out_w, channels),
Raises:
TypeError: size should be a list or tuple or Tensor.
ValueError: 'bilinear' only support 4-D tensor.
ValueError: One of size and scale_factor must not be None.
ValueError: size length should be 2 for input 4-D tensor.
ValueError: scale_factor should be greater than zero.
ValueError: data_format can only be 'NCHW', 'NHWC'.
Examples:
.. code-block:: python
import paddle
import paddle.nn as nn
import numpy as np
paddle.disable_static()
input_data = np.random.rand(2,3,6,10).astype("float32")
upsample_out = paddle.nn.UpsamplingBilinear2d(size=[12,12])
input = paddle.to_tensor(input_data)
output = upsample_out(x=input)
print(output.shape)
# [2L, 3L, 12L, 12L]
"""
def __init__(self,
size=None,
scale_factor=None,
data_format='NCHW',
name=None):
super(UpsamplingBilinear2d, self).__init__()
self.size = size
self.scale_factor = scale_factor
self.data_format = data_format
self.name = name
def forward(self, x):
out = F.interpolate(
x,
size=self.size,
scale_factor=self.scale_factor,
mode='bilinear',
align_corners=True,
align_mode=0,
data_format=self.data_format,
name=self.name)
return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册