未验证 提交 7fedf26b 编写于 作者: F FDInSky 提交者: GitHub

add linear interpolate operator (#23357)

* test=develop add linear interpolate operator
上级 e24575c8
......@@ -21,6 +21,85 @@ namespace operators {
using framework::Tensor;
using DataLayout = framework::DataLayout;
static void Interpolate1DInferShapeCheck(framework::InferShapeContext* ctx) {
auto dim_x = ctx->GetInputDim("X");
auto interp_method = ctx->Attrs().Get<std::string>("interp_method");
PADDLE_ENFORCE_EQ("linear", interp_method,
platform::errors::InvalidArgument(
"Interpolation method can only be \"linear\" when"
"Input(X) dimension is 3, but got method = %s .",
interp_method));
const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout"));
if (ctx->HasInputs("SizeTensor")) {
// top prority size
auto inputs_name = ctx->Inputs("SizeTensor");
PADDLE_ENFORCE_EQ(
inputs_name.size(), 1,
platform::errors::InvalidArgument(
"Input(SizeTensor)'size of Op(interpolate) must be 1. "
"Attr(out_shape)'s length must be 1 for 3-D input tensor, but got "
"size = %d .",
inputs_name.size()));
int out_w = ctx->Attrs().Get<int>("out_w");
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {dim_x[0], dim_x[1], out_w};
} else {
dim_out = {dim_x[0], out_w, dim_x[2]};
}
ctx->SetOutputDim("Out", dim_out);
return;
}
int out_w;
if (ctx->HasInput("Scale")) {
auto scale_tensor = ctx->GetInputDim("Scale");
PADDLE_ENFORCE_EQ(
scale_tensor.size(), 1,
platform::errors::InvalidArgument(
"Scale's dimension size must be 1, but got dimension = %d .",
scale_tensor.size()));
out_w = -1;
} else {
float scale = ctx->Attrs().Get<float>("scale");
if (scale > 0) {
// round down
out_w = (data_layout == DataLayout::kNCHW
? static_cast<int>(dim_x[2] * scale)
: static_cast<int>(dim_x[1] * scale));
// protect when input shape is -1
out_w = out_w > 0 ? out_w : -1;
} else {
out_w = ctx->Attrs().Get<int>("out_w");
}
}
if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
auto out_size_dim = ctx->GetInputDim("OutSize");
PADDLE_ENFORCE_EQ(
out_size_dim.size(), 1,
platform::errors::InvalidArgument(
"OutSize's dimension size must be 1, but got dimention = %d .",
out_size_dim.size()));
PADDLE_ENFORCE_EQ(out_size_dim[0], 1, platform::errors::InvalidArgument(
"OutSize's dim[0] must be 1"));
ctx->ShareLoD("X", "Out");
return;
}
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {dim_x[0], dim_x[1], out_w};
} else {
dim_out = {dim_x[0], out_w, dim_x[2]};
}
ctx->SetOutputDim("Out", dim_out);
}
static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) {
auto dim_x = ctx->GetInputDim("X");
auto interp_method = ctx->Attrs().Get<std::string>("interp_method");
......@@ -29,7 +108,8 @@ static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) {
"bilinear" == interp_method || "nearest" == interp_method ||
"bicubic" == interp_method,
"Interpolation method can only be \"bilinear\" or \"nearest\" when "
"Input(X) dimension is 4");
"Input(X) dimension is 4, but got method = %s .",
interp_method);
const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout"));
......@@ -38,8 +118,11 @@ static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) {
auto inputs_name = ctx->Inputs("SizeTensor");
PADDLE_ENFORCE_EQ(
inputs_name.size(), 2,
"Input(SizeTensor)'size of Op(interpolate) must be 2. "
"Attr(out_shape)'s length must be 2 for 4-D input tensor.");
platform::errors::InvalidArgument(
"Input(SizeTensor)'size of Op(interpolate) must be 2. "
"Attr(out_shape)'s length must be 2 for 4-D input "
"tensor, but got size = %d .",
inputs_name.size()));
int out_h = ctx->Attrs().Get<int>("out_h");
int out_w = ctx->Attrs().Get<int>("out_w");
framework::DDim dim_out;
......@@ -56,8 +139,11 @@ static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) {
int out_h, out_w;
if (ctx->HasInput("Scale")) {
auto scale_tensor = ctx->GetInputDim("Scale");
PADDLE_ENFORCE_EQ(scale_tensor.size(), 1,
"Scale's dimension size must be 1.");
PADDLE_ENFORCE_EQ(
scale_tensor.size(), 1,
platform::errors::InvalidArgument(
"Scale's dimension size must be 1, but got dimension = %d .",
scale_tensor.size()));
out_h = -1;
out_w = -1;
} else {
......@@ -81,9 +167,16 @@ static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) {
if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
auto out_size_dim = ctx->GetInputDim("OutSize");
PADDLE_ENFORCE_EQ(out_size_dim.size(), 1,
"OutSize's dimension size must be 1");
PADDLE_ENFORCE_EQ(out_size_dim[0], 2, "OutSize's dim[0] must be 2");
PADDLE_ENFORCE_EQ(
out_size_dim.size(), 1,
platform::errors::InvalidArgument(
"OutSize's dimension size must be 1, but got dimension = %d .",
out_size_dim.size()));
PADDLE_ENFORCE_EQ(
out_size_dim[0], 2,
platform::errors::InvalidArgument(
"OutSize's dim[0] must be 2, but got dimention = %d .",
out_size_dim[0]));
ctx->ShareLoD("X", "Out");
return;
}
......@@ -101,9 +194,12 @@ static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) {
auto dim_x = ctx->GetInputDim("X");
auto interp_method = ctx->Attrs().Get<std::string>("interp_method");
PADDLE_ENFORCE("trilinear" == interp_method,
"Interpolation method can only be \"trilinear\" when Input(X) "
"dimension is 5");
PADDLE_ENFORCE_EQ(
"trilinear", interp_method,
platform::errors::InvalidArgument(
"Interpolation method can only be \"trilinear\" when Input(X) "
"dimension is 5, but got method = %s .",
interp_method));
const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout"));
......@@ -112,8 +208,11 @@ static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) {
auto inputs_name = ctx->Inputs("SizeTensor");
PADDLE_ENFORCE_EQ(
inputs_name.size(), 3,
"Input(SizeTensor)'s size of Op(interpolate) must be 3. "
"Attr(out_shape)'s length must be 3 for 5-D input tensor.");
platform::errors::InvalidArgument(
"Input(SizeTensor)'s size of Op(interpolate) must be 3. "
"Attr(out_shape)'s length must be 3 for 5-D input "
"tensor, but got size = %d .",
inputs_name.size()));
int out_d = ctx->Attrs().Get<int>("out_d");
int out_h = ctx->Attrs().Get<int>("out_h");
int out_w = ctx->Attrs().Get<int>("out_w");
......@@ -131,8 +230,11 @@ static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) {
int out_d, out_h, out_w;
if (ctx->HasInput("Scale")) {
auto scale_tensor = ctx->GetInputDim("Scale");
PADDLE_ENFORCE_EQ(scale_tensor.size(), 1,
"Scale's dimension size must be 1");
PADDLE_ENFORCE_EQ(
scale_tensor.size(), 1,
platform::errors::InvalidArgument(
"Scale's dimension size must be 1, but got size = %d .",
scale_tensor.size()));
out_d = -1;
out_h = -1;
out_w = -1;
......@@ -163,8 +265,11 @@ static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) {
if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
auto out_size_dim = ctx->GetInputDim("OutSize");
PADDLE_ENFORCE_EQ(out_size_dim.size(), 1,
"OutSize's dimension size must be 1");
PADDLE_ENFORCE_EQ(out_size_dim[0], 3, "OutSize's dim[0] must be 3");
"OutSize's dimension size must be 1, but got size =%d .",
out_size_dim.size());
PADDLE_ENFORCE_EQ(out_size_dim[0], 3,
"OutSize's dim[0] must be 3, but got size = %d .",
out_size_dim[0]);
ctx->ShareLoD("X", "Out");
return;
}
......@@ -190,10 +295,16 @@ class InterpolateOp : public framework::OperatorWithKernel {
"Output(Out) of InterpolationOp should not be null.");
auto dim_x = ctx->GetInputDim("X"); // NCHW format
PADDLE_ENFORCE(dim_x.size() == 4 || dim_x.size() == 5,
"Input(X) dimension must be 4 or 5");
if (dim_x.size() == 4) {
PADDLE_ENFORCE(
dim_x.size() == 3 || dim_x.size() == 4 || dim_x.size() == 5,
platform::errors::Unimplemented(
"Input(X) dimension must be 3, 4 or 5, but got dimension = %d .",
dim_x.size()));
if (dim_x.size() == 3) {
// shape check for 1D interpolate for input tensor shape NCHW
Interpolate1DInferShapeCheck(ctx);
} else if (dim_x.size() == 4) {
// shape check for 2D interpolate for input tensor shape NCHW
Interpolate2DInferShapeCheck(ctx);
} else { // dim_x.size() == 5
......@@ -262,7 +373,8 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<float>("scale", "scale factor of interpolate op.").SetDefault(0.);
AddAttr<std::string>("interp_method",
"(string, default \"bilinear\"), interpolation "
"method, can be \"bilinear\" for "
"method, can be \"linear\" for linear interpolation"
",\"bilinear\" for "
"bilinear interpolation, \"trilinear\" for trilinear "
"interpolation and \"nearest\" for nearest "
"neighbor interpolation, and \"bicubic\" for bicubic"
......@@ -284,12 +396,15 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
This operator samples input X to given output shape by using specified
interpolation method, the interpolation methods can be \"nearest\"
for nearest neighbor interpolation and \"bilinear\" for bilinear
interpolation.
interpolation and \"linear\" for linear interpolation..
Nearest neighbor interpolation is to perform nearest neighbor interpolation
in both the 3rd dimension(in height direction) and the 4th dimension(in width
direction) on input tensor.
Linear interpolation is the method of using a line connecting two known quantities
to determine the value of an unknown quantity between the two known quantities.
Bilinear interpolation is an extension of linear interpolation for
interpolating functions of two variables (e.g. H-direction and
W-direction in this op) on a rectilinear 2D grid. The key idea is
......@@ -512,6 +627,16 @@ REGISTER_OP_CPU_KERNEL(trilinear_interp, ops::InterpolateKernel<float>,
ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(trilinear_interp_grad, ops::InterpolateGradKernel<float>,
ops::InterpolateGradKernel<double>);
REGISTER_OPERATOR(linear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
ops::InterpolateGradMaker<paddle::framework::OpDesc>,
ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(linear_interp_grad, ops::InterpolateOpGrad,
ops::InterpolateGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(linear_interp, ops::InterpolateKernel<float>,
ops::InterpolateKernel<double>,
ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(linear_interp_grad, ops::InterpolateGradKernel<float>,
ops::InterpolateGradKernel<double>);
REGISTER_OP_CPU_KERNEL(bicubic_interp, ops::InterpolateKernel<float>,
ops::InterpolateKernel<double>);
REGISTER_OP_CPU_KERNEL(bicubic_interp_grad, ops::InterpolateGradKernel<float>,
......
......@@ -112,6 +112,119 @@ __global__ void KeNearestNeighborInterpBw(
}
}
template <typename T>
__global__ void KeLinearInterpFw(const T* in, const size_t in_img_w,
const size_t input_w, T* out,
const size_t out_img_w, const size_t output_h,
const size_t output_w,
const size_t num_channels, const float ratio_w,
const bool align_corners, const int align_mode,
const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
bool align_flag = (align_mode == 0 && !align_corners);
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idx = tid % out_img_w;
} else {
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int in_img_idx = align_flag
? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
: static_cast<int>(ratio_w * out_img_idx);
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; // w
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; // w_id
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
src_w = (src_w > 0) ? src_w : 0;
T w1lambda =
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
if (data_layout == DataLayout::kNCHW) {
const T* in_pos =
&in[out_id_h * out_id_w + channel_id * in_img_size + in_img_idx];
// linear interpolation
out[out_id_h * output_w + out_id_w] =
w2lambda * in_pos[0] + w1lambda * in_pos[w_id];
} else {
const T* in_pos =
&in[out_id_h * input_w + in_img_idx * num_channels + channel_id];
// linear interpolation
out[out_id_h * output_w + out_id_w] =
w2lambda * in_pos[0] + w1lambda * in_pos[w_id * num_channels];
}
}
}
template <typename T>
__global__ void KeLinearInterpBw(T* in, const size_t in_img_w,
const size_t input_w, const T* out,
const size_t out_img_w, const size_t output_h,
const size_t output_w,
const size_t num_channels, const T ratio_w,
const bool align_corners, const int align_mode,
const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
bool align_flag = (align_mode == 0 && !align_corners);
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idx = tid % out_img_w;
} else {
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int in_img_idx = align_flag ? ratio_w * (out_img_idx + 0.5) - 0.5
: ratio_w * out_img_idx;
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; // w
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; // w_id
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
src_w = (src_w > 0) ? src_w : 0;
T w1lambda =
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
T* in_pos;
if (data_layout == DataLayout::kNCHW) {
in_pos = &in[out_id_h * input_w + channel_id * in_img_size + in_img_idx];
} else {
in_pos = &in[out_id_h * input_w + in_img_idx * num_channels + channel_id];
}
const T* out_pos = &out[out_id_w];
if (data_layout == DataLayout::kNCHW) {
platform::CudaAtomicAdd(&in_pos[0], w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos[w_id], w1lambda * out_pos[0]);
} else {
platform::CudaAtomicAdd(&in_pos[0], w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos[w_id * num_channels],
w1lambda * out_pos[0]);
}
}
}
template <typename T>
__global__ void KeBilinearInterpFw(
const T* in, const size_t in_img_h, const size_t in_img_w,
......@@ -706,6 +819,84 @@ __global__ void KeBicubicInterpBw(
}
}
template <typename T>
static void Interpolate1DCUDAFwd(const framework::ExecutionContext& ctx,
const Tensor& input, Tensor* output) {
auto* input_data = input.data<T>();
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_w = ctx.Attr<int>("out_w");
auto list_new_shape_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_shape_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_shape_tensor);
out_w = new_size[0];
} else {
float scale;
auto scale_tensor = ctx.Input<Tensor>("Scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
scale = scale_data[0];
} else {
scale = ctx.Attr<float>("scale");
}
if (scale > 0) {
out_w = static_cast<int>(in_w * scale);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopySync(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_w = size_data[0];
}
}
PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument(
"out_w in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {n, c, out_w};
} else {
dim_out = {n, out_w, c};
}
auto output_data = output->mutable_data<T>(dim_out, ctx.GetPlace());
if (in_w == out_w) {
framework::TensorCopy(input, ctx.GetPlace(), output);
return;
}
float ratio_w = 0.f;
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1.0) / (out_w - 1.0)
: static_cast<float>(in_w) / out_w;
}
int in_cw = c * in_w;
int out_cw = c * out_w;
int pixelNum = n * out_cw;
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(pixelNum, ctx);
if ("linear" == interp_method) {
KeLinearInterpFw<T><<<config.blocks, config.threads, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_w, in_cw, output_data, out_w, n, out_cw, c, ratio_w,
align_corners, align_mode, data_layout);
}
}
template <typename T>
static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
const Tensor& input, Tensor* output) {
......@@ -751,12 +942,12 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
out_w = size_data[1];
}
}
PADDLE_ENFORCE_GT(
out_h, 0,
"out_h in Attr(out_shape) of Op(interpolate) should be greater than 0.");
PADDLE_ENFORCE_GT(
out_w, 0,
"out_w in Attr(out_shape) of Op(interpolate) should be greater than 0.");
PADDLE_ENFORCE_GT(out_h, 0, platform::errors::InvalidArgument(
"out_h in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument(
"out_w in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
......@@ -859,15 +1050,15 @@ static void Interpolate3DCUDAFwd(const framework::ExecutionContext& ctx,
out_w = size_data[2];
}
}
PADDLE_ENFORCE_GT(
out_d, 0,
"out_d in Attr(out_shape) of Op(interpolate) should be greater than 0.");
PADDLE_ENFORCE_GT(
out_h, 0,
"out_h in Attr(out_shape) of Op(interpolate) should be greater than 0.");
PADDLE_ENFORCE_GT(
out_w, 0,
"out_w in Attr(out_shape) of Op(interpolate) should be greater than 0.");
PADDLE_ENFORCE_GT(out_d, 0, platform::errors::InvalidArgument(
"out_d in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
PADDLE_ENFORCE_GT(out_h, 0, platform::errors::InvalidArgument(
"out_h in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument(
"out_w in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
......@@ -917,6 +1108,84 @@ static void Interpolate3DCUDAFwd(const framework::ExecutionContext& ctx,
}
}
template <typename T>
static void Interpolate1DCUDABwd(const framework::ExecutionContext& ctx,
Tensor* input_grad, const Tensor output_grad) {
auto* input = ctx.Input<Tensor>("X");
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_w = ctx.Attr<int>("out_w");
float scale;
auto scale_tensor = ctx.Input<Tensor>("Scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
scale = scale_data[0];
} else {
scale = ctx.Attr<float>("scale");
}
if (scale > 0) {
out_w = static_cast<int>(in_w * scale);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopySync(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_w = size_data[0];
}
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_size_tensor);
out_w = new_size[0];
}
auto* output_grad_data = output_grad.data<T>();
framework::DDim dim_grad;
if (data_layout == DataLayout::kNCHW) {
dim_grad = {n, c, in_w};
} else {
dim_grad = {n, in_w, c};
}
input_grad->mutable_data<T>(dim_grad, ctx.GetPlace());
auto* input_grad_data = input_grad->mutable_data<T>(dim_grad, ctx.GetPlace());
auto& device_ctx = ctx.template device_context<platform::CUDADeviceContext>();
math::SetConstant<platform::CUDADeviceContext, T> zero;
zero(device_ctx, input_grad, static_cast<T>(0.0));
if (in_w == out_w) {
framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
return;
}
float ratio_w = 0.f;
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
int in_cw = c * in_w;
int out_cw = c * out_w;
int pixelNum = n * out_cw;
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(pixelNum, ctx);
if ("linear" == interp_method) {
KeLinearInterpBw<T><<<config.blocks, config.threads, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_w, in_cw, output_grad_data, out_w, n, out_cw, c,
ratio_w, align_corners, align_mode, data_layout);
}
}
template <typename T>
static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
Tensor* input_grad, const Tensor output_grad) {
......@@ -1124,13 +1393,16 @@ template <typename T>
class InterpolateOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::NotFound("This kernel only runs on GPU device."));
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
auto input_dims = input->dims();
if (input_dims.size() == 4) { // 2D interpolation
if (input_dims.size() == 3) { // 1D interpolation
Interpolate1DCUDAFwd<T>(ctx, *input, output);
} else if (input_dims.size() == 4) { // 2D interpolation
Interpolate2DCUDAFwd<T>(ctx, *input, output);
} else if (input_dims.size() == 5) { // 3D interpolation
Interpolate3DCUDAFwd<T>(ctx, *input, output);
......@@ -1142,13 +1414,16 @@ template <typename T>
class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::NotFound("This kernel only runs on GPU device."));
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto output_grad_dims = output_grad->dims();
if (output_grad_dims.size() == 4) { // 2D interpolation
if (output_grad_dims.size() == 3) { // 1D interpolation
Interpolate1DCUDABwd<T>(ctx, input_grad, *output_grad);
} else if (output_grad_dims.size() == 4) { // 2D interpolation
Interpolate2DCUDABwd<T>(ctx, input_grad, *output_grad);
} else if (output_grad_dims.size() == 5) { // 3D interpolation
Interpolate3DCUDABwd<T>(ctx, input_grad, *output_grad);
......@@ -1178,6 +1453,12 @@ REGISTER_OP_CUDA_KERNEL(trilinear_interp, ops::InterpolateOpCUDAKernel<float>,
REGISTER_OP_CUDA_KERNEL(trilinear_interp_grad,
ops::InterpolateGradOpCUDAKernel<float>,
ops::InterpolateGradOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(linear_interp, ops::InterpolateOpCUDAKernel<float>,
ops::InterpolateOpCUDAKernel<double>,
ops::InterpolateOpCUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(linear_interp_grad,
ops::InterpolateGradOpCUDAKernel<float>,
ops::InterpolateGradOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(bicubic_interp, ops::InterpolateOpCUDAKernel<float>,
ops::InterpolateOpCUDAKernel<double>,
ops::InterpolateOpCUDAKernel<int>);
......
......@@ -32,12 +32,12 @@ inline std::vector<int> get_new_shape(
std::vector<int> vec_new_shape;
for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) {
auto tensor = list_new_shape_tensor[i];
PADDLE_ENFORCE_EQ(tensor->dims(), framework::make_ddim({1}),
"shape of dim tensor should be [1]");
PADDLE_ENFORCE_EQ(
tensor->dims(), framework::make_ddim({1}),
platform::errors::InvalidArgument("shape of dim tensor should be [1]"));
if (platform::is_gpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
vec_new_shape.push_back(static_cast<int32_t>(*temp.data<int32_t>()));
} else {
vec_new_shape.push_back(static_cast<int32_t>(*tensor->data<int32_t>()));
......@@ -64,7 +64,13 @@ inline void ExtractNCDWH(const framework::DDim& dims,
const DataLayout& data_layout, int* N, int* C, int* D,
int* H, int* W) {
*N = dims[0];
if (dims.size() == 4) {
if (dims.size() == 3) {
*C = data_layout == DataLayout::kNCHW ? dims[1] : dims[2];
*D = 1;
*H = 1;
*W = data_layout == DataLayout::kNCHW ? dims[2] : dims[1];
} else if (dims.size() == 4) {
*C = data_layout == DataLayout::kNCHW ? dims[1] : dims[3];
*D = 1;
*H = data_layout == DataLayout::kNCHW ? dims[2] : dims[1];
......@@ -107,6 +113,103 @@ static void NearestNeighborInterpolate(const Tensor& input, Tensor* output,
}
}
template <typename T>
static void LinearInterpolation(const Tensor& input, Tensor* output,
const float ratio_w, const int in_w,
const int n, const int c, const int out_w,
const bool align_corners, const bool align_mode,
const DataLayout data_layout) {
auto input_t = EigenTensor<T, 3>::From(input);
auto output_t = EigenTensor<T, 3>::From(*output);
bool align_flag = (align_mode == 0 && !align_corners);
std::vector<int> vx_w, vx_e;
std::vector<float> vd_w, vd_e;
vx_w.reserve(out_w);
vx_e.reserve(out_w);
vd_w.reserve(out_w);
vd_e.reserve(out_w);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int l = 0; l < out_w; l++) {
int x_w = align_flag ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
: static_cast<int>(ratio_w * l);
x_w = (x_w > 0) ? x_w : 0; // w
int x_e = (x_w < (in_w - 1)) ? (x_w + 1) : x_w; // w_id
float idx_src_x = ratio_w * (l + 0.5) - 0.5;
idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w; // w1lambda
float d_e = 1.f - d_w; // w2lambda
{
vx_w[l] = x_w;
vx_e[l] = x_e;
vd_w[l] = d_w;
vd_e[l] = d_e;
}
}
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(3)
#endif
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
for (int l = 0; l < out_w; l++) {
// linear interpolation
T out_t;
if (data_layout == DataLayout::kNCHW) {
out_t = input_t(i, j, vx_w[l]) * vd_e[l] +
input_t(i, j, vx_e[l]) * vd_w[l];
output_t(i, j, l) = out_t;
} else {
out_t = input_t(i, vx_w[l], j) * vd_e[l] +
input_t(i, vx_e[l], j) * vd_w[l];
output_t(i, l, j) = out_t;
}
}
}
}
}
template <typename T>
static void LinearInterpolationGrad(const Tensor& output_grad,
Tensor* input_grad, const float ratio_w,
const int in_w, const int n, const int c,
const int out_w, const bool align_corners,
const int align_mode,
const DataLayout data_layout) {
auto input_grad_t = EigenTensor<T, 3>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 3>::From(output_grad);
bool align_flag = (align_mode == 0 && !align_corners);
for (int l = 0; l < out_w; l++) {
int x_w = align_flag ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
: static_cast<int>(ratio_w * l);
x_w = (x_w > 0) ? x_w : 0; // w
int x_e = (x_w < (in_w - 1)) ? (x_w + 1) : x_w; // w_id
float idx_src_x = ratio_w * (l + 0.5) - 0.5;
idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w; // w1lambda
float d_e = 1.f - d_w; // w2lambda
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
// linear interpolation grad
if (data_layout == DataLayout::kNCHW) {
const T grad = output_grad_t(i, j, l);
input_grad_t(i, j, x_w) += static_cast<T>(grad * d_e);
input_grad_t(i, j, x_e) += static_cast<T>(grad * d_w);
} else {
const T grad = output_grad_t(i, l, j);
input_grad_t(i, x_w, j) += static_cast<T>(grad * d_e);
input_grad_t(i, x_e, j) += static_cast<T>(grad * d_w);
}
}
}
}
}
template <typename T>
static void BilinearInterpolation(const Tensor& input, Tensor* output,
const float ratio_h, const float ratio_w,
......@@ -666,6 +769,69 @@ static void BicubicInterpolationGrad(const Tensor& output_grad,
}
}
template <typename T>
static void Interpolate1DCPUFwd(const framework::ExecutionContext& ctx,
const Tensor& input, Tensor* output) {
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_w = ctx.Attr<int>("out_w");
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_size_tensor);
out_w = new_size[0];
} else {
float scale;
auto scale_tensor = ctx.Input<Tensor>("Scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
scale = scale_data[0];
} else {
scale = ctx.Attr<float>("scale");
}
if (scale > 0) {
out_w = static_cast<int>(in_w * scale);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
auto out_size_data = get_new_data_from_tensor<int>(out_size);
out_w = out_size_data[0];
}
}
PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument(
"out_w in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {n, c, out_w};
} else {
dim_out = {n, out_w, c};
}
output->mutable_data<T>(dim_out, ctx.GetPlace());
if (in_w == out_w) {
framework::TensorCopy(input, ctx.GetPlace(), output);
return;
}
float ratio_w = 0.f;
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
if ("linear" == interp_method) {
LinearInterpolation<T>(input, output, ratio_w, in_w, n, c, out_w,
align_corners, align_mode, data_layout);
}
}
template <typename T>
static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx,
const Tensor& input, Tensor* output) {
......@@ -707,12 +873,12 @@ static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx,
out_w = out_size_data[1];
}
}
PADDLE_ENFORCE_GT(
out_h, 0,
"out_h in Attr(out_shape) of Op(interpolate) should be greater than 0.");
PADDLE_ENFORCE_GT(
out_w, 0,
"out_w in Attr(out_shape) of Op(interpolate) should be greater than 0.");
PADDLE_ENFORCE_GT(out_h, 0, platform::errors::InvalidArgument(
"out_h in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument(
"out_w in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {n, c, out_h, out_w};
......@@ -795,15 +961,15 @@ static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx,
out_w = out_size_data[2];
}
}
PADDLE_ENFORCE_GT(
out_d, 0,
"out_d in Attr(out_shape) of Op(interpolate) should be greater than 0.");
PADDLE_ENFORCE_GT(
out_h, 0,
"out_h in Attr(out_shape) of Op(interpolate) should be greater than 0.");
PADDLE_ENFORCE_GT(
out_w, 0,
"out_w in Attr(out_shape) of Op(interpolate) should be greater than 0.");
PADDLE_ENFORCE_GT(out_d, 0, platform::errors::InvalidArgument(
"out_d in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
PADDLE_ENFORCE_GT(out_h, 0, platform::errors::InvalidArgument(
"out_h in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument(
"out_w in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
......@@ -842,6 +1008,71 @@ static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx,
}
}
template <typename T>
static void Interpolate1DCPUBwd(const framework::ExecutionContext& ctx,
Tensor* input_grad, const Tensor& output_grad) {
auto* input = ctx.Input<Tensor>("X");
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_w = ctx.Attr<int>("out_w");
float scale;
auto scale_tensor = ctx.Input<Tensor>("Scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
scale = scale_data[0];
} else {
scale = ctx.Attr<float>("scale");
}
if (scale > 0) {
out_w = static_cast<int>(in_w * scale);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
auto out_size_data = get_new_data_from_tensor<int>(out_size);
out_w = out_size_data[0];
}
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape(list_new_size_tensor);
out_w = new_size[0];
}
framework::DDim dim_grad;
if (data_layout == DataLayout::kNCHW) {
dim_grad = {n, c, in_w};
} else {
dim_grad = {n, in_w, c};
}
input_grad->mutable_data<T>(dim_grad, ctx.GetPlace());
auto& device_ctx = ctx.template device_context<platform::CPUDeviceContext>();
math::SetConstant<platform::CPUDeviceContext, T> zero;
zero(device_ctx, input_grad, static_cast<T>(0.0));
if (in_w == out_w) {
framework::TensorCopy(output_grad, ctx.GetPlace(), input_grad);
return;
}
float ratio_w = 0.f;
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
if ("linear" == interp_method) {
LinearInterpolationGrad<T>(output_grad, input_grad, ratio_w, in_w, n, c,
out_w, align_corners, align_mode, data_layout);
}
}
template <typename T>
static void Interpolate2DCPUBwd(const framework::ExecutionContext& ctx,
Tensor* input_grad, const Tensor& output_grad) {
......@@ -1018,7 +1249,9 @@ class InterpolateKernel : public framework::OpKernel<T> {
auto* output = ctx.Output<Tensor>("Out");
auto input_dims = input->dims();
if (input_dims.size() == 4) { // 2D interpolation
if (input_dims.size() == 3) { // 1D interpolation
Interpolate1DCPUFwd<T>(ctx, *input, output);
} else if (input_dims.size() == 4) { // 2D interpolation
Interpolate2DCPUFwd<T>(ctx, *input, output);
} else if (input_dims.size() == 5) { // 3D interpolation
Interpolate3DCPUFwd<T>(ctx, *input, output);
......@@ -1034,7 +1267,9 @@ class InterpolateGradKernel : public framework::OpKernel<T> {
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto output_grad_dims = output_grad->dims();
if (output_grad_dims.size() == 4) { // 2D interpolation grad
if (output_grad_dims.size() == 3) { // 1D interpolation grad
Interpolate1DCPUBwd<T>(ctx, input_grad, *output_grad);
} else if (output_grad_dims.size() == 4) { // 2D interpolation grad
Interpolate2DCPUBwd<T>(ctx, input_grad, *output_grad);
} else if (output_grad_dims.size() == 5) { // 3D interpolation grad
Interpolate3DCPUBwd<T>(ctx, input_grad, *output_grad);
......
......@@ -95,6 +95,7 @@ __all__ = [
'dice_loss',
'image_resize',
'image_resize_short',
'resize_linear',
'resize_bilinear',
'resize_trilinear',
'resize_nearest',
......@@ -6889,7 +6890,8 @@ def image_resize(input,
"""
This op resizes a batch of images.
The input must be a 4-D Tensor of the shape (num_batches, channels, in_h, in_w)
The input must be a 3-D Tensor of the shape (num_batches, channels, in_w)
or a 4-D Tensor of the shape (num_batches, channels, in_h, in_w)
or (num_batches, in_h, in_w, channels), or a 5-D Tensor of the shape
(num_batches, channels, in_d, in_h, in_w) or (num_batches, in_d, in_h, in_w, channels),
and the resizing only applies on the three dimensions(depth, height and width).
......@@ -6898,13 +6900,17 @@ def image_resize(input,
future and only use :attr:`out_shape` instead.
Supporting resample methods:
'LINEAR' : Linear interpolation
'BILINEAR' : Bilinear interpolation
'TRILINEAR' : Trilinear interpolation
'NEAREST' : Nearest neighbor interpolation
Linear interpolation is the method of using a line connecting two known quantities
to determine the value of an unknown quantity between the two known quantities.
Nearest neighbor interpolation is to perform nearest neighbor interpolation
in both the 3rd dimension(in height direction) and the 4th dimension(in width
direction) on input tensor.
......@@ -6958,6 +6964,23 @@ def image_resize(input,
H_out = round(H_{in} * scale_{factor})
W_out = round(W_{in} * scale_{factor})
linear interpolation:
if:
align_corners = False , align_mode = 0
input : (N,C,W_in)
output: (N,C,W_out) where:
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,W_in)
output: (N,C,H_out,W_out) where:
W_out = W_{in} * scale_{factor}
Bilinear interpolation:
if:
......@@ -7061,15 +7084,17 @@ def image_resize(input,
TypeError: actual_shape should either be Variable or None.
ValueError: The 'resample' of image_resize can only be 'BILINEAR',
'TRILINEAR' or 'NEAREST' currently.
ValueError: 'LINEAR' only support 3-D tensor.
ValueError: 'BILINEAR' and 'NEAREST' only support 4-D tensor.
ValueError: 'TRILINEAR' only support 5-D tensor.
ValueError: One of out_shape and scale must not be None.
ValueError: out_shape length should be 1 for input 3-D tensor.
ValueError: out_shape length should be 2 for input 4-D tensor.
ValueError: out_shape length should be 3 for input 5-D tensor.
ValueError: scale should be greater than zero.
TypeError: align_corners should be a bool value
ValueError: align_mode can only be '0' or '1'
ValueError: data_format can only be 'NCHW', 'NHWC', 'NCDHW' or 'NDHWC'.
ValueError: data_format can only be 'NCW', 'NCHW', 'NHWC', 'NCDHW' or 'NDHWC'.
Examples:
.. code-block:: python
......@@ -7134,19 +7159,24 @@ def image_resize(input,
"""
resample_methods = {
'LINEAR': 'linear',
'BILINEAR': 'bilinear',
'TRILINEAR': 'trilinear',
'NEAREST': 'nearest',
'LINEAR': 'linear',
}
resample = resample.upper()
if resample not in resample_methods:
raise ValueError(
"The 'resample' of image_resize can only be 'BILINEAR', 'TRILINEAR' "
"The 'resample' of image_resize can only be 'LINEAR', 'BILINEAR', 'TRILINEAR' "
"or 'NEAREST' currently.")
resample_type = resample_methods[resample]
if resample in ['BILINEAR', 'NEAREST'] and len(input.shape) != 4:
if resample == 'LINEAR' and len(input.shape) != 3:
raise ValueError("'LINER only support 3-D tensor.")
elif resample in ['BILINEAR', 'NEAREST'] and len(input.shape) != 4:
raise ValueError("'BILINEAR' and 'NEAREST' only support 4-D tensor.")
if resample == 'TRILINEAR' and len(input.shape) != 5:
elif resample == 'TRILINEAR' and len(input.shape) != 5:
raise ValueError("'TRILINEAR'only support 5-D tensor.")
if not isinstance(align_corners, bool):
......@@ -7159,7 +7189,11 @@ def image_resize(input,
helper = LayerHelper('{}_interp'.format(resample_type), **locals())
dtype = helper.input_dtype()
if len(input.shape) == 4 and data_format not in ['NCHW', 'NHWC']:
if len(input.shape) == 3 and data_format not in ['NCHW', 'NHWC']:
raise ValueError(
"Got wrong value for param `data_format`: " + data_format +
" received but only `NCHW` or `NHWC` supported for 3-D input.")
elif len(input.shape) == 4 and data_format not in ['NCHW', 'NHWC']:
raise ValueError(
"Got wrong value for param `data_format`: " + data_format +
" received but only `NCHW` or `NHWC` supported for 4-D input.")
......@@ -7223,7 +7257,16 @@ def image_resize(input,
size_list.append(dim)
inputs['SizeTensor'] = new_size_tensor
if len(input.shape) == 4:
if len(input.shape) == 3:
if len(out_shape) != 1:
raise ValueError("out_shape length should be 1 for "
"input 3-D tensor.")
if contain_var:
attrs['out_w'] = size_list[0]
else:
out_shape = list(map(int, out_shape))
attrs['out_w'] = out_shape[0]
elif len(input.shape) == 4:
if len(out_shape) != 2:
raise ValueError("out_shape length should be 2 for "
"input 4-D tensor.")
......@@ -7269,7 +7312,6 @@ def image_resize(input,
inputs["OutSize"] = actual_shape
elif actual_shape is not None:
raise TypeError("actual_shape should either be Variable or None.")
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='{}_interp'.format(resample_type),
......@@ -7279,6 +7321,132 @@ def image_resize(input,
return out
@templatedoc(op_type="linear_interp")
def resize_linear(input,
out_shape=None,
scale=None,
name=None,
actual_shape=None,
align_corners=True,
align_mode=1,
data_format='NCHW'):
"""
This op resizes the input by performing linear interpolation based on given
output shape which specified by actual_shape, out_shape and scale
in priority order.
**Warning:** the parameter :attr:`actual_shape` will be deprecated in
the future and only use :attr:`out_shape` instead.
Align_corners and align_mode are optional parameters,the calculation
method of interpolation can be selected by them.
Example:
.. code-block:: text
For scale:
if align_corners = True && out_size > 1 :
scale_factor = (in_size-1.0)/(out_size-1.0)
else:
scale_factor = float(in_size/out_size)
Linear interpolation:
if:
align_corners = False , align_mode = 0
input : (N,C,W_in)
output: (N,C,W_out) where:
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,W_in)
output: (N,C,W_out) where:
W_out = W_{in} * scale_{factor}
Parameters:
input(Variable): 3-D Tensor(NCHW), its data type is float32, float64, or uint8,
its data format is specified by :attr:`data_format`.
out_shape(list|tuple|Variable|None): Output shape of resize linear
layer, the shape is (out_w,). Default: None. If a list, each
element can be an integer or a Tensor Variable with shape: [1]. If a
Tensor Variable, its dimension size should be 1.
scale(float|Variable|None): The multiplier for the input height or width. At
least one of :attr:`out_shape` or :attr:`scale` must be set.
And :attr:`out_shape` has a higher priority than :attr:`scale`.
Default: None.
actual_shape(Variable): An optional input to specify output shape
dynamically. If provided, image resize
according to this given shape rather than
:attr:`out_shape` and :attr:`scale` specifying
shape. That is to say actual_shape has the
highest priority. It is recommended to use
:attr:`out_shape` if you want to specify output
shape dynamically, because :attr:`actual_shape`
will be deprecated. When using actual_shape to
specify output shape, one of :attr:`out_shape`
and :attr:`scale` should also be set, otherwise
errors would be occurred in graph constructing stage.
Default: None
align_corners(bool): ${align_corners_comment}
align_mode(bool): ${align_mode_comment}
data_format (str, optional): Specify the data format of the input, and the data format of the output
will be consistent with that of the input. An optional string from: `"NCHW"`, `"NHWC"`.
The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
`[batch_size, input_channels, input_height, input_width]`.
name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`
Returns:
Variable: 3-D tensor(NCHW or NHWC).
Examples:
.. code-block:: python
#declarative mode
import paddle.fluid as fluid
import numpy as np
input = fluid.data(name="input", shape=[None,3,100])
output = fluid.layers.resize_linear(input=input,out_shape=[50,])
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
input_data = np.random.rand(1,3,100).astype("float32")
output_data = exe.run(fluid.default_main_program(),
feed={"input":input_data},
fetch_list=[output],
return_numpy=True)
print(output_data[0].shape)
# (1, 3, 50)
#imperative mode
import paddle.fluid.dygraph as dg
with dg.guard(place) as g:
input = dg.to_variable(input_data)
output = fluid.layers.resize_linear(input=input, out_shape=[50,])
print(output.shape)
# [1L, 3L, 50L]
"""
return image_resize(input, out_shape, scale, name, 'LINEAR', actual_shape,
align_corners, align_mode, data_format)
@templatedoc(op_type="bilinear_interp")
def resize_bilinear(input,
out_shape=None,
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import platform
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
from paddle.nn.functional import *
def linear_interp_np(input,
out_w,
out_size=None,
actual_shape=None,
align_corners=True,
align_mode=0,
data_layout='NCHW'):
if data_layout == "NHWC":
input = np.transpose(input, (0, 2, 1)) # NHWC => NCHW
if out_size is not None:
out_w = out_size[0]
if actual_shape is not None:
out_w = actual_shape[0]
batch_size, channel, in_w = input.shape
ratio_w = 0.0
if out_w > 1:
if (align_corners):
ratio_w = (in_w - 1.0) / (out_w - 1.0)
else:
ratio_w = 1.0 * in_w / out_w
out = np.zeros((batch_size, channel, out_w))
for j in range(out_w):
if (align_mode == 0 and not align_corners):
w = int(ratio_w * (j + 0.5) - 0.5)
else:
w = int(ratio_w * j)
w = max(0, w)
wid = 1 if w < in_w - 1 else 0
if (align_mode == 0 and not align_corners):
idx_src_w = max(ratio_w * (j + 0.5) - 0.5, 0)
w1lambda = idx_src_w - w
else:
w1lambda = ratio_w * j - w
w2lambda = 1.0 - w1lambda
out[:, :, j] = w2lambda * input[:, :, w] + w1lambda * input[:, :, w +
wid]
if data_layout == "NHWC":
out = np.transpose(out, (0, 2, 1)) # NCHW => NHWC
return out.astype(input.dtype)
class TestLinearInterpOp(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.data_layout = 'NCHW'
self.init_test_case()
self.op_type = "linear_interp"
input_np = np.random.random(self.input_shape).astype("float64")
if self.data_layout == "NCHW":
in_w = self.input_shape[2]
else:
in_w = self.input_shape[1]
if self.scale > 0:
out_w = int(in_w * self.scale)
else:
out_w = self.out_w
output_np = linear_interp_np(input_np, out_w, self.out_size,
self.actual_shape, self.align_corners,
self.align_mode, self.data_layout)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
if self.actual_shape is not None:
self.inputs['OutSize'] = self.actual_shape
self.attrs = {
'out_w': self.out_w,
'scale': self.scale,
'interp_method': self.interp_method,
'align_corners': self.align_corners,
'align_mode': self.align_mode,
'data_layout': self.data_layout
}
self.outputs = {'Out': output_np}
def test_check_output(self):
if platform.system() == "Linux":
self.check_output(atol=1e-7)
else:
self.check_output(atol=1e-5)
def test_check_grad(self):
self.check_grad(['X'], 'Out', in_place=True)
def init_test_case(self):
self.interp_method = 'linear'
self.input_shape = [1, 3, 100]
self.out_w = 50
self.scale = 0.
self.out_size = np.array([50, ]).astype("int32")
self.align_corners = False
self.align_mode = 1
class TestLinearInterpOpDataLayout(TestLinearInterpOp):
def init_test_case(self):
self.interp_method = 'linear'
self.input_shape = [1, 3, 100]
self.out_w = 50
self.scale = 0.
self.out_size = np.array([50, ]).astype("int32")
self.align_corners = False
self.align_mode = 1
self.data_layout = 'NHWC'
class TestLinearInterpOpAlignMode(TestLinearInterpOp):
def init_test_case(self):
self.interp_method = 'linear'
self.input_shape = [1, 3, 100]
self.out_w = 50
self.scale = 0.
self.out_size = np.array([50, ]).astype("int32")
self.align_corners = False
self.align_mode = 0
class TestLinearInterpOpScale(TestLinearInterpOp):
def init_test_case(self):
self.interp_method = 'linear'
self.input_shape = [1, 3, 100]
self.out_w = 50
self.scale = 0.5
self.out_size = np.array([50, ]).astype("int32")
self.align_corners = False
self.align_mode = 0
class TestLinearInterpOpSizeTensor(TestLinearInterpOp):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.data_layout = 'NCHW'
self.init_test_case()
self.op_type = "linear_interp"
input_np = np.random.random(self.input_shape).astype("float64")
self.shape_by_1Dtensor = False
self.scale_by_1Dtensor = False
if self.data_layout == "NCHW":
in_w = self.input_shape[2]
else:
in_w = self.input_shape[1]
if self.scale > 0:
out_w = int(in_w * self.scale)
else:
out_w = self.out_w
output_np = linear_interp_np(input_np, out_w, self.out_size,
self.actual_shape, self.align_corners,
self.align_mode, self.data_layout)
self.inputs = {'X': input_np}
if self.out_size is not None and self.shape_by_1Dtensor:
self.inputs['OutSize'] = self.out_size
elif self.actual_shape is not None and self.shape_by_1Dtensor:
self.inputs['OutSize'] = self.actual_shape
else:
size_tensor = []
for index, ele in enumerate(self.out_size):
size_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs['SizeTensor'] = size_tensor
self.attrs = {
'out_w': self.out_w,
'scale': self.scale,
'interp_method': self.interp_method,
'align_corners': self.align_corners,
'align_mode': self.align_mode,
'data_layout': self.data_layout
}
self.outputs = {'Out': output_np}
class TestLinearInterpOpAPI(unittest.TestCase):
def test_case(self):
x = fluid.data(name="x", shape=[1, 3, 128], dtype="float32")
shape_tensor = fluid.data(name="shape_tensor", shape=[1], dtype="int32")
scale_tensor = fluid.data(
name="scale_tensor", shape=[1], dtype="float32")
dim = fluid.data(name="dim", shape=[1], dtype="int32")
actual_size = fluid.data(name='actual_size', shape=[1], dtype='int32')
out1 = fluid.layers.resize_linear(
x, out_shape=[256, ], align_mode=1, align_corners=False)
out2 = fluid.layers.resize_linear(
x, out_shape=shape_tensor, align_mode=1, align_corners=False)
out3 = fluid.layers.resize_linear(
x, scale=scale_tensor, align_mode=1, align_corners=False)
out4 = fluid.layers.resize_linear(
x, out_shape=[dim, ], align_mode=1, align_corners=False)
out5 = fluid.layers.resize_linear(
x,
out_shape=[256, ],
actual_shape=actual_size,
align_mode=1,
align_corners=False)
x_data = np.random.random((1, 3, 128)).astype("float32")
shape_data = np.array([256, ]).astype("int32")
scale_data = np.array([2.0, ]).astype("float32")
dim_data = np.array([256, ]).astype("int32")
actual_size_data = np.array([256, ]).astype("int32")
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
results = exe.run(fluid.default_main_program(),
feed={
"x": x_data,
"shape_tensor": shape_data,
"scale_tensor": scale_data,
"dim": dim_data,
'actual_size': actual_size_data,
},
fetch_list=[out1, out2, out3, out4, out5],
return_numpy=True)
expect_res = linear_interp_np(
x_data, out_w=256, align_mode=1, align_corners=False)
for res in results:
self.assertTrue(np.allclose(res, expect_res))
class TestLinearInterpOpAPI2_Func(unittest.TestCase):
def test_case(self):
x = fluid.data(name="x", shape=[1, 3, 128], dtype="float32")
shape_tensor = fluid.data(name="shape_tensor", shape=[1], dtype="int32")
scale_tensor = fluid.data(
name="scale_tensor", shape=[1], dtype="float32")
dim = fluid.data(name="dim", shape=[1], dtype="int32")
actual_size = fluid.data(name='actual_size', shape=[1], dtype='int32')
out1 = interpolate(
x,
out_shape=[256, ],
align_mode=1,
align_corners=False,
resample='LINEAR')
out2 = interpolate(
x,
out_shape=shape_tensor,
align_mode=1,
align_corners=False,
resample='LINEAR')
out3 = interpolate(
x,
scale=scale_tensor,
align_mode=1,
align_corners=False,
resample='LINEAR')
out4 = interpolate(
x,
out_shape=[dim, ],
align_mode=1,
align_corners=False,
resample='LINEAR')
out5 = interpolate(
x,
out_shape=[256, ],
actual_shape=actual_size,
align_mode=1,
align_corners=False,
resample='LINEAR')
x_data = np.random.random((1, 3, 128)).astype("float32")
shape_data = np.array([256, ]).astype("int32")
scale_data = np.array([2.0, ]).astype("float32")
dim_data = np.array([256, ]).astype("int32")
actual_size_data = np.array([256, ]).astype("int32")
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
results = exe.run(fluid.default_main_program(),
feed={
"x": x_data,
"shape_tensor": shape_data,
"scale_tensor": scale_data,
"dim": dim_data,
'actual_size': actual_size_data,
},
fetch_list=[out1, out2, out3, out4, out5],
return_numpy=True)
expect_res = linear_interp_np(
x_data, out_w=256, align_mode=1, align_corners=False)
for res in results:
self.assertTrue(np.allclose(res, expect_res))
class TestLinearInterpOpAPI2_0(unittest.TestCase):
def test_case(self):
# dygraph
x_data = np.random.random((1, 3, 128)).astype("float32")
us_1 = paddle.nn.UpSample(
out_shape=[64, ],
resample='LINEAR',
align_mode=1,
align_corners=False)
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(x_data)
interp = us_1(x)
expect = linear_interp_np(
x_data, out_w=64, align_mode=1, align_corners=False)
self.assertTrue(np.allclose(interp.numpy(), expect))
class TestLinearInterpOpUint8(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.init_test_case()
self.op_type = "linear_interp"
input_np = np.random.random(self.input_shape).astype("uint8")
if self.scale > 0:
out_w = int(self.input_shape[3] * self.scale)
else:
out_w = self.out_w
output_np = linear_interp_np(input_np, out_w, self.out_size,
self.actual_shape, self.align_corners,
self.align_mode)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
self.attrs = {
'out_w': self.out_w,
'scale': self.scale,
'interp_method': self.interp_method,
'align_corners': self.align_corners,
'align_mode': self.align_mode
}
self.outputs = {'Out': output_np}
def test_check_output(self):
if platform.system() == "Linux":
self.check_output_with_place(place=core.CPUPlace(), atol=1e-7)
else:
self.check_output_with_place(place=core.CPUPlace(), atol=1e-5)
def init_test_case(self):
self.interp_method = 'linear'
self.input_shape = [2, 3, 100]
self.out_w = 50
self.scale = 0.
self.out_size = np.array([50, ]).astype("int32")
self.align_corners = True
self.align_mode = 1
class TestLinearInterpOpException(unittest.TestCase):
def test_exception(self):
def input_shape_error():
x1 = fluid.data(name="x1", shape=[1], dtype="float32")
out = fluid.layers.resize_linear(
x1, out_shape=[256, ], data_format='NCW')
def data_format_error():
x2 = fluid.data(name="x2", shape=[1, 3, 128], dtype="float32")
out = fluid.layers.resize_linear(
x2, out_shape=[256, ], data_format='NHWCD')
def out_shape_error():
x3 = fluid.data(name="x3", shape=[1, 3, 128], dtype="float32")
out = fluid.layers.resize_linear(
x3, out_shape=[
256,
256,
], data_format='NHWC')
self.assertRaises(ValueError, input_shape_error)
self.assertRaises(ValueError, data_format_error)
self.assertRaises(ValueError, out_shape_error)
class TestLinearInterpOpError(unittest.TestCase):
def test_error(self):
with program_guard(Program(), Program()):
def input_shape_error():
x1 = fluid.data(name="x1", shape=[1], dtype="float32")
out1 = paddle.nn.UpSample(
out_shape=[256, ], data_format='NCW', resample='LINEAR')
out1_res = out1(x1)
def data_format_error():
x2 = fluid.data(name="x2", shape=[1, 3, 128], dtype="float32")
out2 = paddle.nn.UpSample(
out_shape=[256, ], data_format='NHWCD', resample='LINEAR')
out2_res = out2(x2)
def out_shape_error():
x3 = fluid.data(name="x3", shape=[1, 3, 128], dtype="float32")
out3 = paddle.nn.UpSample(
out_shape=[
256,
256,
],
data_format='NHWC',
resample='LINEAR')
out3_res = out3(x3)
self.assertRaises(ValueError, input_shape_error)
self.assertRaises(ValueError, data_format_error)
self.assertRaises(ValueError, out_shape_error)
if __name__ == "__main__":
unittest.main()
......@@ -17,10 +17,12 @@
from .layer import norm
from .functional import extension
from .layer import common
__all__ = []
__all__ += norm.__all__
__all__ += extension.__all__
__all__ += common.__all__
# TODO: define alias in nn directory
# from .clip import ErrorClipByValue #DEFINE_ALIAS
......@@ -64,7 +66,7 @@ from .layer.common import BilinearTensorProduct #DEFINE_ALIAS
from .layer.common import Pool2D #DEFINE_ALIAS
from .layer.common import Embedding #DEFINE_ALIAS
from .layer.common import Linear #DEFINE_ALIAS
# from .layer.common import UpSample #DEFINE_ALIAS
from .layer.common import UpSample #DEFINE_ALIAS
from .layer.conv import Conv2D #DEFINE_ALIAS
from .layer.conv import Conv2DTranspose #DEFINE_ALIAS
from .layer.conv import Conv3D #DEFINE_ALIAS
......
......@@ -284,6 +284,7 @@ def interpolate(input,
# [2L, 3L, 12L, 12L]
"""
resample_methods = {
'LINEAR': 'linear',
'BILINEAR': 'bilinear',
'TRILINEAR': 'trilinear',
'NEAREST': 'nearest',
......@@ -291,10 +292,13 @@ def interpolate(input,
}
if resample not in resample_methods:
raise ValueError(
"The 'resample' of image_resize can only be 'BILINEAR', 'TRILINEAR', "
"The 'resample' of image_resize can only be 'LINEAR', 'BILINEAR', 'TRILINEAR', "
" 'BICUBIC' or 'NEAREST' currently.")
resample_type = resample_methods[resample]
if resample in ['LINEAR'] and len(input.shape) != 3:
raise ValueError("'LINEAR' only support 3-D tensor.")
if resample in ['BILINEAR', 'NEAREST', 'BICUBIC'] and len(input.shape) != 4:
raise ValueError(
"'BILINEAR', 'BICUBIC' and 'NEAREST' only support 4-D tensor.")
......@@ -311,7 +315,11 @@ def interpolate(input,
helper = LayerHelper('{}_interp'.format(resample_type), **locals())
dtype = helper.input_dtype()
if len(input.shape) == 4 and data_format not in ['NCHW', 'NHWC']:
if len(input.shape) == 3 and data_format not in ['NCHW', 'NHWC']:
raise ValueError(
"Got wrong value for param `data_format`: " + data_format +
" received but only `NCHW` or `NHWC` supported for 3-D input.")
elif len(input.shape) == 4 and data_format not in ['NCHW', 'NHWC']:
raise ValueError(
"Got wrong value for param `data_format`: " + data_format +
" received but only `NCHW` or `NHWC` supported for 4-D input.")
......@@ -375,6 +383,15 @@ def interpolate(input,
size_list.append(dim)
inputs['SizeTensor'] = new_size_tensor
if len(input.shape) == 3:
if len(out_shape) != 1:
raise ValueError(
"out_shape length should be 2 for input 3-D tensor")
if contain_var:
attrs['out_w'] = size_list[0]
else:
out_shape = list(map(int, out_shape))
attrs['out_w'] = out_shape[0]
if len(input.shape) == 4:
if len(out_shape) != 2:
raise ValueError("out_shape length should be 2 for "
......
......@@ -37,7 +37,7 @@ from .common import BilinearTensorProduct #DEFINE_ALIAS
from .common import Pool2D #DEFINE_ALIAS
from .common import Embedding #DEFINE_ALIAS
from .common import Linear #DEFINE_ALIAS
# from .common import UpSample #DEFINE_ALIAS
from .common import UpSample #DEFINE_ALIAS
from .conv import Conv2D #DEFINE_ALIAS
from .conv import Conv2DTranspose #DEFINE_ALIAS
from .conv import Conv3D #DEFINE_ALIAS
......
......@@ -17,11 +17,232 @@ from ...fluid.dygraph import BilinearTensorProduct #DEFINE_ALIAS
from ...fluid.dygraph import Pool2D #DEFINE_ALIAS
from ...fluid.dygraph import Embedding #DEFINE_ALIAS
from ...fluid.dygraph import Linear #DEFINE_ALIAS
from ...fluid.dygraph import layers
from .. import functional as F
__all__ = [
'BilinearTensorProduct',
'Pool2D',
'Embedding',
'Linear',
# 'UpSample'
]
__all__ = ['BilinearTensorProduct', 'Pool2D', 'Embedding', 'Linear', 'UpSample']
class UpSample(layers.Layer):
"""
This op resizes a batch of images.
The input must be a 3-D Tensor of the shape (num_batches, channels, in_w)
or 4-D (num_batches, channels, in_h, in_w), or a 5-D Tensor of the shape
(num_batches, channels, in_d, in_h, in_w) or (num_batches, in_d, in_h, in_w, channels),
and the resizing only applies on the three dimensions(depth, height and width).
**Warning:** the parameter :attr:`actual_shape` will be deprecated in the
future and only use :attr:`out_shape` instead.
Supporting resample methods:
'LINEAR' : linear interpolation
'BILINEAR' : Bilinear interpolation
'TRILINEAR' : Trilinear interpolation
'NEAREST' : Nearest neighbor interpolation
'BICUBIC' : Bicubic interpolation
Linear interpolation is the method of using a line connecting two known quantities
to determine the value of an unknown quantity between the two known quantities.
Nearest neighbor interpolation is to perform nearest neighbor interpolation
in both the 3rd dimension(in height direction) and the 4th dimension(in width
direction) on input tensor.
Bilinear interpolation is an extension of linear interpolation for
interpolating functions of two variables (e.g. H-direction and
W-direction in this op) on a rectilinear 2D grid. The key idea is
to perform linear interpolation first in one direction, and then
again in the other direction.
Trilinear interpolation is an extension of linear interpolation for
interpolating functions of three variables (e.g. D-direction,
H-direction and W-direction in this op) on a rectilinear 3D grid.
The linear interpolation is performed on three directions.
Align_corners and align_mode are optional parameters,the calculation method
of interpolation can be selected by them.
Bicubic interpolation is an extension of cubic interpolation for interpolating
data points on a two-dimensional regular grid. The interpolated surface is
smoother than corresponding surfaces obtained by bilinear interpolation or
nearest-neighbor interpolation.
Example:
.. code-block:: text
For scale:
if align_corners = True && out_size > 1 :
scale_factor = (in_size-1.0)/(out_size-1.0)
else:
scale_factor = float(in_size/out_size)
Nearest neighbor interpolation:
if:
align_corners = False
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = floor (H_{in} * scale_{factor})
W_out = floor (W_{in} * scale_{factor})
else:
align_corners = True
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = round(H_{in} * scale_{factor})
W_out = round(W_{in} * scale_{factor})
Linear interpolation:
if:
align_corners = False , align_mode = 0
input : (N,C,W_in)
output: (N,C,W_out) where:
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,W_in)
output: (N,C,W_out) where:
W_out = W_{in} * scale_{factor}
Bilinear interpolation:
if:
align_corners = False , align_mode = 0
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = (H_{in}+0.5) * scale_{factor} - 0.5
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
Bicubic interpolation:
if:
align_corners = False
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = (H_{in}+0.5) * scale_{factor} - 0.5
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,H_in,W_in)
output: (N,C,H_out,W_out) where:
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
Trilinear interpolation:
if:
align_corners = False , align_mode = 0
input : (N,C,D_in,H_in,W_in)
output: (N,C,D_out,H_out,W_out) where:
D_out = (D_{in}+0.5) * scale_{factor} - 0.5
H_out = (H_{in}+0.5) * scale_{factor} - 0.5
W_out = (W_{in}+0.5) * scale_{factor} - 0.5
else:
input : (N,C,D_in,H_in,W_in)
output: (N,C,D_out,H_out,W_out) where:
D_out = D_{in} * scale_{factor}
H_out = H_{in} * scale_{factor}
W_out = W_{in} * scale_{factor}
For details of nearest neighbor interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation.
For details of bilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Linear_interpolation.
For details of bilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bilinear_interpolation.
For details of trilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Trilinear_interpolation.
For details of bicubic interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bicubic_interpolation
Parameters:
input (Variable): 3-D, 4-D or 5-D Tensor, its data type is float32, float64, or uint8,
its data format is specified by :attr:`data_format`.
out_shape(list|tuple|Variable|None): Output shape of image resize
layer, the shape is (out_w, ) when input is 3-D Tensor ,
the shape is (out_h, out_w) when input is a 4-D Tensor and is
(out_d, out_h, out_w) when input is a 5-D Tensor. Default: None. If
a list, each element can be an integer or a Tensor Variable of shape: [1].
If a Tensor Variable, its dimensions size should be a 1.
scale(float|Variable|None): The multiplier for the input height or width. At
least one of :attr:`out_shape` or :attr:`scale` must be set.
And :attr:`out_shape` has a higher priority than :attr:`scale`.
Default: None.
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
resample(str): The resample method. It supports 'LINEAR', 'BILINEAR', 'TRILINEAR' ,
'BICUBIC' and 'NEAREST' currently. Default: 'BILINEAR'
align_corners(bool) : An optional bool, If True, the centers of the 4 corner pixels of the
input and output tensors are aligned, preserving the values at the
corner pixels.
Default: True
align_mode(int) : An optional for bilinear interpolation. can be \'0\'
for src_idx = scale*(dst_indx+0.5)-0.5 , can be \'1\' for
src_idx = scale*dst_index.
data_format (str, optional): Specify the data format of the input, and the data format of the output
will be consistent with that of the input. An optional string from:'NCW', `"NCHW"`, `"NHWC"`, `"NCDHW"`,
`"NDHWC"`. The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
`[batch_size, input_channels, input_height, input_width]`. When it is `"NCHW"`, the data is stored
in the order of: `[batch_size, input_channels, input_depth, input_height, input_width]`.
Returns:
A 3-D Tensor of the shape (num_batches, channels, out_w) or (num_batches, out_w, channels),
A 4-D Tensor of the shape (num_batches, channels, out_h, out_w) or (num_batches, out_h, out_w, channels),
or 5-D Tensor of the shape (num_batches, channels, out_d, out_h, out_w) or (num_batches, out_d, out_h, out_w, channels).
Raises:
TypeError: out_shape should be a list or tuple or Variable.
TypeError: actual_shape should either be Variable or None.
ValueError: The 'resample' of image_resize can only be 'BILINEAR',
'TRILINEAR', 'BICUBIC', or 'NEAREST' currently.
ValueError: 'BILINEAR', 'BICUBIC' and 'NEAREST' only support 4-D tensor.
ValueError: 'TRILINEAR' only support 5-D tensor.
ValueError: One of out_shape and scale must not be None.
ValueError: out_shape length should be 2 for input 4-D tensor.
ValueError: out_shape length should be 3 for input 5-D tensor.
ValueError: scale should be greater than zero.
TypeError: align_corners should be a bool value
ValueError: align_mode can only be '0' or '1'
ValueError: data_format can only be 'NCW', 'NCHW', 'NHWC', 'NCDHW' or 'NDHWC'.
Examples:
.. code-block:: python
import paddle
import numpy as np
import paddle.fluid.dygraph as dg
upsample_op = paddle.nn.UpSample(out_shape=[12,12])
input_data = np.random.rand(2,3,6,10).astype("float32")
place = paddle.fluid.CPUPlace()
with dg.guard(place) as g:
input = dg.to_variable(input_data)
output = upsample_op(input=input)
print(output.shape)
# [2L, 3L, 12L, 12L]
"""
def __init__(self,
out_shape=None,
scale=None,
resample='BILINEAR',
align_corners=True,
align_mode=1,
data_format='NCHW'):
super(UpSample, self).__init__()
self.out_shape = out_shape
self.scale = scale
self.resample = resample
self.align_corners = align_corners
self.align_mode = align_mode
self.data_format = data_format
def forward(self, input):
out = F.interpolate(
input,
out_shape=self.out_shape,
scale=self.scale,
resample=self.resample,
align_corners=self.align_corners,
align_mode=self.align_mode,
data_format=self.data_format)
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册