提交 34bfae24 编写于 作者: D dengkaipeng

Add Interpolate operation. test=develop

上级 df4a3544
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/bilinear_interp_op.h"
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class BilinearInterpOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of BilinearInterOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of BilinearInterOp should not be null.");
auto dim_x = ctx->GetInputDim("X"); // NCHW format
int out_h = ctx->Attrs().Get<int>("out_h");
int out_w = ctx->Attrs().Get<int>("out_w");
PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4");
if (ctx->HasInput("OutSize")) {
auto out_size_dim = ctx->GetInputDim("OutSize");
PADDLE_ENFORCE_EQ(out_size_dim.size(), 1,
"OutSize's dimension size must be 1");
PADDLE_ENFORCE_EQ(out_size_dim[0], 2, "OutSize's dim[0] must be 2");
}
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace());
}
};
class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"The input tensor of bilinear interpolation, "
"This is a 4-D tensor with shape of (N x C x h x w)");
AddInput("OutSize",
"This is a 1-D tensor with two number. "
"The first number is height and the second number is width.")
.AsDispensable();
AddOutput("Out", "The dimension of output is (N x C x out_h x out_w)");
AddAttr<int>("out_h", "output height of bilinear interpolation op.");
AddAttr<int>("out_w", "output width of bilinear interpolation op.");
AddComment(R"DOC(
Bilinear interpolation is an extension of linear interpolation for
interpolating functions of two variables (e.g. H-direction and
W-direction in this op) on a rectilinear 2D grid.
The key idea is to perform linear interpolation first in one
direction, and then again in the other direction.
For details, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bilinear_interpolation
)DOC");
}
};
class BilinearInterpOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto dim_x = ctx->GetInputDim("X");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(bilinear_interp, ops::BilinearInterpOp,
ops::BilinearInterpOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(bilinear_interp_grad, ops::BilinearInterpOpGrad);
REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::BilinearInterpKernel<float>,
ops::BilinearInterpKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(bilinear_interp_grad,
ops::BilinearInterpGradKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/bilinear_interp_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
namespace operators {
using framework::Tensor;
template <typename T>
__global__ void KeBilinearInterpFw(
const T* in, const size_t in_img_h, const size_t in_img_w,
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const T ratio_h, const T ratioW) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < nthreads) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id = out_id_w / out_img_size;
int out_img_idy = (out_id_w % out_img_size) / out_img_w;
int in_img_idy = ratio_h * out_img_idy;
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
T h1lambda = ratio_h * out_img_idy - in_img_idy;
T h2lambda = 1.f - h1lambda;
int out_img_idx = tid % out_img_w;
int in_img_idx = ratioW * out_img_idx;
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
T w1lambda = ratioW * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
in_img_idy * in_img_w + in_img_idx];
// bilinear interpolation
out[out_id_h * output_w + out_id_w] =
h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) +
h1lambda * (w2lambda * in_pos[h_id * in_img_w] +
w1lambda * in_pos[h_id * in_img_w + w_id]);
}
}
template <typename T>
__global__ void KeBilinearInterpBw(
T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
const size_t input_w, const T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const T ratio_h, const T ratioW) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < nthreads) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id = out_id_w / out_img_size;
int out_img_idy = (out_id_w % out_img_size) / out_img_w;
int in_img_idy = ratio_h * out_img_idy;
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
T h1lambda = ratio_h * out_img_idy - in_img_idy;
T h2lambda = 1.f - h1lambda;
int out_img_idx = tid % out_img_w;
int in_img_idx = ratioW * out_img_idx;
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
T w1lambda = ratioW * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
in_img_idy * in_img_w + in_img_idx];
const T* out_pos = &out[out_id_h * output_w + out_id_w];
atomicAdd(&in_pos[0], h2lambda * w2lambda * out_pos[0]);
atomicAdd(&in_pos[w_id], h2lambda * w1lambda * out_pos[0]);
atomicAdd(&in_pos[h_id * in_img_w], h1lambda * w2lambda * out_pos[0]);
atomicAdd(&in_pos[h_id * in_img_w + w_id],
h1lambda * w1lambda * out_pos[0]);
}
}
template <typename T>
class BilinearInterpOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto* input_t = ctx.Input<Tensor>("X"); // float tensor
auto* output_t = ctx.Output<Tensor>("Out"); // float tensor
auto* input = input_t->data<T>();
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
auto out_dims = output_t->dims();
auto out_size_t = ctx.Input<Tensor>("OutSize");
if (out_size_t != nullptr) {
Tensor sizes;
framework::TensorCopy(*out_size_t, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_h = size_data[0];
out_w = size_data[1];
}
auto* output = output_t->mutable_data<T>(
{out_dims[0], out_dims[1], out_h, out_w}, ctx.GetPlace());
int batch_size = input_t->dims()[0];
int channels = input_t->dims()[1];
int in_h = input_t->dims()[2];
int in_w = input_t->dims()[3];
int in_hw = in_h * in_w;
int out_hw = out_h * out_w;
int in_chw = channels * in_hw;
int out_chw = channels * out_hw;
T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;
if (in_h == out_h && in_w == out_w) {
memcpy(output, input, input_t->numel() * sizeof(T));
} else {
int threadNum = batch_size * out_chw;
int blocks = (threadNum + 1024 - 1) / 1024;
KeBilinearInterpFw<
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
input, in_h, in_w, batch_size, in_chw, output, out_h, out_w,
batch_size, out_chw, channels, ratio_h, ratio_w);
}
}
};
template <typename T>
class BilinearInterpGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* d_input_t = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* d_output_t = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* d_output = d_output_t->data<T>();
auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace());
auto& device_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
math::SetConstant<platform::CUDADeviceContext, T> zero;
zero(device_ctx, d_input_t, static_cast<T>(0.0));
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
auto out_size_t = ctx.Input<Tensor>("OutSize");
if (out_size_t != nullptr) {
Tensor sizes;
framework::TensorCopy(*out_size_t, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_h = size_data[0];
out_w = size_data[1];
}
int batch_size = d_input_t->dims()[0];
int channels = d_input_t->dims()[1];
int in_h = d_input_t->dims()[2];
int in_w = d_input_t->dims()[3];
int in_hw = in_h * in_w;
int out_hw = out_h * out_w;
int in_chw = channels * in_hw;
int out_chw = channels * out_hw;
T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;
if (in_h == out_h && in_w == out_w) {
memcpy(d_input, d_output, d_input_t->numel() * sizeof(T));
} else {
int threadNum = batch_size * out_chw;
int blocks = (threadNum + 1024 - 1) / 1024;
KeBilinearInterpBw<
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
d_input, in_h, in_w, batch_size, in_chw, d_output, out_h, out_w,
batch_size, out_chw, channels, ratio_h, ratio_w);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(bilinear_interp,
ops::BilinearInterpOpCUDAKernel<float>);
REGISTER_OP_CUDA_KERNEL(bilinear_interp_grad,
ops::BilinearInterpGradOpCUDAKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class BilinearInterpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input_t = ctx.Input<Tensor>("X"); // float tensor
auto* output_t = ctx.Output<Tensor>("Out"); // float tensor
auto out_dims = output_t->dims();
auto* input = input_t->data<T>();
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
auto out_size_t = ctx.Input<Tensor>("OutSize");
if (out_size_t != nullptr) {
auto out_size_data = out_size_t->data<int>();
out_h = out_size_data[0];
out_w = out_size_data[1];
}
auto* output = output_t->mutable_data<T>(
{out_dims[0], out_dims[1], out_h, out_w}, ctx.GetPlace());
int batch_size = input_t->dims()[0];
int channels = input_t->dims()[1];
int in_h = input_t->dims()[2];
int in_w = input_t->dims()[3];
int in_hw = in_h * in_w;
int out_hw = out_h * out_w;
int in_chw = channels * in_hw;
int out_chw = channels * out_hw;
float ratio_h =
(out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
float ratio_w =
(out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
if (in_h == out_h && in_w == out_w) {
memcpy(output, input, input_t->numel() * sizeof(T));
} else {
for (int k = 0; k < batch_size; ++k) { // loop for batches
for (int i = 0; i < out_h; ++i) { // loop for images
int h = ratio_h * i;
int hid = (h < in_h - 1) ? 1 : 0;
float h1lambda = ratio_h * i - h;
float h2lambda = 1.f - h1lambda;
for (int j = 0; j < out_w; ++j) {
int w = ratio_w * j;
int wid = (w < in_w - 1) ? 1 : 0;
float w1lambda = ratio_w * j - w;
float w2lambda = 1.f - w1lambda;
// calculate four position for bilinear interpolation
const T* in_pos = &input[k * in_chw + h * in_w + w];
T* out_pos = &output[k * out_chw + i * out_w + j];
for (int c = 0; c < channels; ++c) { // loop for channels
// bilinear interpolation
out_pos[0] = static_cast<T>(
h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[wid]) +
h1lambda * (w2lambda * in_pos[hid * in_w] +
w1lambda * in_pos[hid * in_w + wid]));
in_pos += in_hw;
out_pos += out_hw;
}
}
}
}
}
}
};
template <typename T>
class BilinearInterpGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* d_input_t = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* d_output_t = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* d_output = d_output_t->data<T>();
auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace());
auto& device_ctx =
ctx.template device_context<platform::CPUDeviceContext>();
math::SetConstant<platform::CPUDeviceContext, T> zero;
zero(device_ctx, d_input_t, static_cast<T>(0.0));
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
auto out_size_t = ctx.Input<Tensor>("OutSize");
if (out_size_t != nullptr) {
auto out_size_data = out_size_t->data<int>();
out_h = out_size_data[0];
out_w = out_size_data[1];
}
int batch_size = d_input_t->dims()[0];
int channels = d_input_t->dims()[1];
int in_h = d_input_t->dims()[2];
int in_w = d_input_t->dims()[3];
int in_hw = in_h * in_w;
int out_hw = out_h * out_w;
int in_chw = channels * in_hw;
int out_chw = channels * out_hw;
float ratio_h =
(out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
float ratio_w =
(out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
if (in_h == out_h && in_w == out_w) {
memcpy(d_input, d_output, d_input_t->numel() * sizeof(T));
} else {
for (int k = 0; k < batch_size; ++k) { // loop for batches
for (int i = 0; i < out_h; ++i) { // loop for images
int h = ratio_h * i;
int hid = (h < in_h - 1) ? 1 : 0;
float h1lambda = ratio_h * i - h;
float h2lambda = 1 - h1lambda;
for (int j = 0; j < out_w; ++j) {
int w = ratio_w * j;
int wid = (w < in_w - 1) ? 1 : 0;
float w1lambda = ratio_w * j - w;
float w2lambda = 1 - w1lambda;
T* in_pos = &d_input[k * in_chw + h * in_w + w];
const T* out_pos = &d_output[k * out_chw + i * out_w + j];
for (int c = 0; c < channels; ++c) { // loop for channels
in_pos[0] += static_cast<T>(h2lambda * w2lambda * out_pos[0]);
in_pos[wid] += static_cast<T>(h2lambda * w1lambda * out_pos[0]);
in_pos[hid * in_w] +=
static_cast<T>(h1lambda * w2lambda * out_pos[0]);
in_pos[hid * in_w + wid] +=
static_cast<T>(h1lambda * w1lambda * out_pos[0]);
in_pos += in_hw;
out_pos += out_hw;
}
}
}
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -9,7 +9,8 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/nearest_neighbor_interp_op.h"
#include "paddle/fluid/operators/interpolate_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
......@@ -18,16 +19,21 @@ namespace operators {
using framework::Tensor;
class NearestNeighborInterpOp : public framework::OperatorWithKernel {
class InterpolateOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of NearestNeighborInterOp should not be null.");
"Input(X) of InterpolateOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of NearestNeighborInterOp should not be null.");
"Output(Out) of InterpolationOp should not be null.");
auto interp_method = ctx->Attrs().Get<std::string>("interp_method");
PADDLE_ENFORCE(
"bilinear" == interp_method || "nearest" == interp_method,
"Interpolation method can only be \"bilinear\" or \"nearest\".");
auto dim_x = ctx->GetInputDim("X"); // NCHW format
int out_h = ctx->Attrs().Get<int>("out_h");
......@@ -52,33 +58,53 @@ class NearestNeighborInterpOp : public framework::OperatorWithKernel {
}
};
class NearestNeighborInterpOpMaker : public framework::OpProtoAndCheckerMaker {
class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"The input tensor of nearest neighbor interpolation, "
"This is a 4-D tensor with shape of (N x C x h x w)");
"The input tensor of interpolate operator, "
"This is a 4-D tensor with shape of [N, C, H, w].");
AddInput("OutSize",
"This is a 1-D tensor with two number. "
"This is a 1-D tensor with two numbers to specify output size. "
"The first number is height and the second number is width.")
.AsDispensable();
AddOutput("Out", "The dimension of output is (N x C x out_h x out_w)");
AddOutput("Out",
"The output tensor of interpolate operator, "
"This is a 4-D tensor with shape of [N, C, H, W].");
AddAttr<int>("out_h",
"output height of nearest neighbor interpolation op.");
AddAttr<int>("out_w", "output width of nearest neighbor interpolation op.");
AddAttr<int>("out_h", "output height of interpolate op.");
AddAttr<int>("out_w", "output width of interpolate op.");
AddAttr<std::string>(
"interp_method",
"(string), interpolation method, can be \"bilinear\" for "
"bilinear interpolation and \"nearest\" for nearest "
"neighbor interpolation.");
AddComment(R"DOC(
This operator samples input X to given output shape by using specified
interpolation method, the interpolation methods can be \"nearest\"
for nearest neighbor interpolation and \"bilinear\" for bilinear
interpolation.
Nearest neighbor interpolation is to perform nearest neighbor interpolation
in bot the 3rd dimention(in height direction) and the 4th dimention(in width
direction) on input tensor.
For details, please refer to Wikipedia:
Bilinear interpolation is an extension of linear interpolation for
interpolating functions of two variables (e.g. H-direction and
W-direction in this op) on a rectilinear 2D grid. The key idea is
to perform linear interpolation first in one direction, and then
again in the other direction.
For details of nearest neighbor interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation
For details of bilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bilinear_interpolation
)DOC");
}
};
class NearestNeighborInterpOpGrad : public framework::OperatorWithKernel {
class InterpolateOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -104,13 +130,11 @@ class NearestNeighborInterpOpGrad : public framework::OperatorWithKernel {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(nearest_neighbor_interp, ops::NearestNeighborInterpOp,
ops::NearestNeighborInterpOpMaker,
REGISTER_OPERATOR(interpolate, ops::InterpolateOp, ops::InterpolateOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(nearest_neighbor_interp_grad,
ops::NearestNeighborInterpOpGrad);
REGISTER_OP_CPU_KERNEL(nearest_neighbor_interp,
ops::NearestNeighborInterpKernel<float>,
ops::NearestNeighborInterpKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(nearest_neighbor_interp_grad,
ops::NearestNeighborInterpGradKernel<float>);
REGISTER_OPERATOR(interpolate_grad, ops::InterpolateOpGrad);
REGISTER_OP_CPU_KERNEL(interpolate, ops::InterpolateKernel<float>,
ops::InterpolateKernel<double>,
ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(interpolate_grad, ops::InterpolateGradKernel<float>,
ops::InterpolateGradKernel<double>);
......@@ -9,7 +9,8 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/nearest_neighbor_interp_op.h"
#include <string>
#include "paddle/fluid/operators/interpolate_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
......@@ -22,7 +23,7 @@ __global__ void KeNearestNeighborInterpFw(
const T* in, const size_t in_img_h, const size_t in_img_w,
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const T ratio_h, const T ratio_w) {
const size_t num_channels, const float ratio_h, const float ratio_w) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < nthreads) {
......@@ -33,10 +34,10 @@ __global__ void KeNearestNeighborInterpFw(
int channel_id = out_id_w / out_img_size;
int out_img_idy = (out_id_w % out_img_size) / out_img_w;
int in_img_idy = static_cast<int>(round(ratio_h * out_img_idy));
int in_img_idy = static_cast<int>(ratio_h * out_img_idy + 0.5);
int out_img_idx = tid % out_img_w;
int in_img_idx = static_cast<int>(round(ratio_w * out_img_idx));
int in_img_idx = static_cast<int>(ratio_w * out_img_idx + 0.5);
out[tid] = in[out_id_h * input_w + channel_id * in_img_size +
in_img_idy * in_img_w + in_img_idx];
......@@ -48,7 +49,7 @@ __global__ void KeNearestNeighborInterpBw(
T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
const size_t input_w, const T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const T ratio_h, const T ratio_w) {
const size_t num_channels, const float ratio_h, const float ratio_w) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < nthreads) {
......@@ -59,28 +60,106 @@ __global__ void KeNearestNeighborInterpBw(
int channel_id = out_id_w / out_img_size;
int out_img_idy = (out_id_w % out_img_size) / out_img_w;
int in_img_idy = static_cast<int>(round(ratio_h * out_img_idy));
int in_img_idy = static_cast<int>(ratio_h * out_img_idy + 0.5);
int out_img_idx = tid % out_img_w;
int in_img_idx = static_cast<int>(round(ratio_w * out_img_idx));
int in_img_idx = static_cast<int>(ratio_w * out_img_idx + 0.5);
T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
in_img_idy * in_img_w + in_img_idx];
const T out_pos = out[out_id_h * output_w + out_id_w];
atomicAdd(in_pos, out_pos);
platform::CudaAtomicAdd(in_pos, out_pos);
}
}
template <typename T>
__global__ void KeBilinearInterpFw(
const T* in, const size_t in_img_h, const size_t in_img_w,
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const float ratio_h, const float ratio_w) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < nthreads) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id = out_id_w / out_img_size;
int out_img_idy = (out_id_w % out_img_size) / out_img_w;
int in_img_idy = ratio_h * out_img_idy;
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
T h1lambda = ratio_h * out_img_idy - in_img_idy;
T h2lambda = 1.f - h1lambda;
int out_img_idx = tid % out_img_w;
int in_img_idx = ratio_w * out_img_idx;
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
T w1lambda = ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
in_img_idy * in_img_w + in_img_idx];
// bilinear interpolation
out[out_id_h * output_w + out_id_w] =
h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) +
h1lambda * (w2lambda * in_pos[h_id * in_img_w] +
w1lambda * in_pos[h_id * in_img_w + w_id]);
}
}
template <typename T>
class NearestNeighborInterpOpCUDAKernel : public framework::OpKernel<T> {
__global__ void KeBilinearInterpBw(
T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
const size_t input_w, const T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const T ratio_h, const T ratio_w) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < nthreads) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id = out_id_w / out_img_size;
int out_img_idy = (out_id_w % out_img_size) / out_img_w;
int in_img_idy = ratio_h * out_img_idy;
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
T h1lambda = ratio_h * out_img_idy - in_img_idy;
T h2lambda = 1.f - h1lambda;
int out_img_idx = tid % out_img_w;
int in_img_idx = ratio_w * out_img_idx;
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
T w1lambda = ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
in_img_idy * in_img_w + in_img_idx];
const T* out_pos = &out[out_id_h * output_w + out_id_w];
platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos[w_id], h2lambda * w1lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos[h_id * in_img_w],
h1lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos[h_id * in_img_w + w_id],
h1lambda * w1lambda * out_pos[0]);
}
}
template <typename T>
class InterpolateOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto* input = ctx.Input<Tensor>("X"); // float tensor
auto* output = ctx.Output<Tensor>("Out"); // float tensor
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
auto* input_data = input->data<T>();
auto interp_method = ctx.Attr<std::string>("interp_method");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
auto out_size = ctx.Input<Tensor>("OutSize");
......@@ -105,26 +184,35 @@ class NearestNeighborInterpOpCUDAKernel : public framework::OpKernel<T> {
int in_chw = c * in_hw;
int out_chw = c * out_hw;
T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;
float ratio_h =
(out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
float ratio_w =
(out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
if (in_h == out_h && in_w == out_w) {
memcpy(output_data, input_data, input->numel() * sizeof(T));
framework::TensorCopy(*input, ctx.GetPlace(), output);
return;
}
int threadNum = n * out_chw;
int blocks = (threadNum + 1024 - 1) / 1024;
KeNearestNeighborInterpFw<
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w);
if ("nearest" == interp_method) {
KeNearestNeighborInterpFw<
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w);
} else if ("bilinear" == interp_method) {
KeBilinearInterpFw<
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w);
}
}
};
template <typename T>
class NearestNeighborInterpGradOpCUDAKernel : public framework::OpKernel<T> {
class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
......@@ -137,9 +225,9 @@ class NearestNeighborInterpGradOpCUDAKernel : public framework::OpKernel<T> {
math::SetConstant<platform::CUDADeviceContext, T> zero;
zero(device_ctx, input_grad, static_cast<T>(0.0));
auto interp_method = ctx.Attr<std::string>("interp_method");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
......@@ -159,21 +247,30 @@ class NearestNeighborInterpGradOpCUDAKernel : public framework::OpKernel<T> {
int in_chw = c * in_hw;
int out_chw = c * out_hw;
T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;
float ratio_h =
(out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
float ratio_w =
(out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
if (in_h == out_h && in_w == out_w) {
memcpy(input_grad, output_grad, input_grad->numel() * sizeof(T));
framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad);
return;
}
int threadNum = n * out_chw;
int blocks = (threadNum + 1024 - 1) / 1024;
KeNearestNeighborInterpBw<
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
n, out_chw, c, ratio_h, ratio_w);
if ("nearest" == interp_method) {
KeNearestNeighborInterpBw<
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
out_w, n, out_chw, c, ratio_h, ratio_w);
} else if ("bilinear" == interp_method) {
KeBilinearInterpBw<
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
out_w, n, out_chw, c, ratio_h, ratio_w);
}
}
};
......@@ -181,7 +278,9 @@ class NearestNeighborInterpGradOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(nearest_neighbor_interp,
ops::NearestNeighborInterpOpCUDAKernel<float>);
REGISTER_OP_CUDA_KERNEL(nearest_neighbor_interp_grad,
ops::NearestNeighborInterpGradOpCUDAKernel<float>);
REGISTER_OP_CUDA_KERNEL(interpolate, ops::InterpolateOpCUDAKernel<float>,
ops::InterpolateOpCUDAKernel<double>,
ops::InterpolateOpCUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(interpolate_grad,
ops::InterpolateGradOpCUDAKernel<float>,
ops::InterpolateGradOpCUDAKernel<double>);
......@@ -10,6 +10,7 @@
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
......@@ -22,12 +23,126 @@ using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Tensor = framework::Tensor;
template <typename T>
class NearestNeighborInterpKernel : public framework::OpKernel<T> {
static void NearestNeighborInterpolate(const Tensor& input, Tensor* output,
const float ratio_h, const float ratio_w,
const int n, const int c,
const int out_h, const int out_w) {
auto input_t = EigenTensor<T, 4>::From(input);
auto output_t = EigenTensor<T, 4>::From(*output);
for (int k = 0; k < out_h; k++) { // loop for images
int in_k = static_cast<int>(ratio_h * k + 0.5);
for (int l = 0; l < out_w; l++) {
int in_l = static_cast<int>(ratio_w * l + 0.5);
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
output_t(i, j, k, l) = input_t(i, j, in_k, in_l);
}
}
}
}
}
template <typename T>
static void BilinearInterpolation(const Tensor& input, Tensor* output,
const float ratio_h, const float ratio_w,
const int in_h, const int in_w, const int n,
const int c, const int out_h,
const int out_w) {
auto input_t = EigenTensor<T, 4>::From(input);
auto output_t = EigenTensor<T, 4>::From(*output);
for (int k = 0; k < out_h; k++) { // loop for images
int y_n = static_cast<int>(ratio_h * k);
int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
float d_n = ratio_h * k - y_n;
float d_s = 1.f - d_n;
for (int l = 0; l < out_w; l++) {
int x_w = static_cast<int>(ratio_w * l);
int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
float d_w = ratio_w * l - x_w;
float d_e = 1.f - d_w;
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
// bilinear interpolation
output_t(i, j, k, l) = input_t(i, j, y_n, x_w) * d_s * d_e +
input_t(i, j, y_s, x_w) * d_n * d_e +
input_t(i, j, y_n, x_e) * d_s * d_w +
input_t(i, j, y_s, x_e) * d_n * d_w;
}
}
}
}
}
template <typename T>
static void NearestNeighborInterpolateGrad(const Tensor& output_grad,
Tensor* input_grad,
const float ratio_h,
const float ratio_w, const int n,
const int c, const int out_h,
const int out_w) {
auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
for (int k = 0; k < out_h; k++) { // loop for images
int in_k = static_cast<int>(ratio_h * k + 0.5);
for (int l = 0; l < out_w; l++) {
int in_l = static_cast<int>(ratio_w * l + 0.5);
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
input_grad_t(i, j, in_k, in_l) += output_grad_t(i, j, k, l);
}
}
}
}
}
template <typename T>
static void BilinearInterpolationGrad(const Tensor& output_grad,
Tensor* input_grad, const float ratio_h,
const float ratio_w, const int in_h,
const int in_w, const int n, const int c,
const int out_h, const int out_w) {
auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
for (int k = 0; k < out_h; k++) { // loop for images
int y_n = static_cast<int>(ratio_h * k);
int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
float d_n = ratio_h * k - y_n;
float d_s = 1.f - d_n;
for (int l = 0; l < out_w; l++) {
int x_w = static_cast<int>(ratio_w * l);
int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
float d_w = ratio_w * l - x_w;
float d_e = 1.f - d_w;
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
// bilinear interpolation grad
const T grad = output_grad_t(i, j, k, l);
input_grad_t(i, j, y_n, x_w) += static_cast<T>(grad * d_s * d_e);
input_grad_t(i, j, y_s, x_w) += static_cast<T>(grad * d_n * d_e);
input_grad_t(i, j, y_n, x_e) += static_cast<T>(grad * d_s * d_w);
input_grad_t(i, j, y_s, x_e) += static_cast<T>(grad * d_n * d_w);
}
}
}
}
}
template <typename T>
class InterpolateKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
std::string interp_method = ctx.Attr<std::string>("interp_method");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
auto out_size = ctx.Input<Tensor>("OutSize");
......@@ -58,30 +173,25 @@ class NearestNeighborInterpKernel : public framework::OpKernel<T> {
float ratio_w =
(out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
auto input_t = EigenTensor<T, 4>::From(*input);
auto output_t = EigenTensor<T, 4>::From(*output);
for (int k = 0; k < out_h; k++) { // loop for images
int in_k = static_cast<int>(round(ratio_h * k));
for (int l = 0; l < out_w; l++) {
int in_l = static_cast<int>(round(ratio_w * l));
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
output_t(i, j, k, l) = input_t(i, j, in_k, in_l);
}
}
}
if ("bilinear" == interp_method) {
BilinearInterpolation<T>(*input, output, ratio_h, ratio_w, in_h, in_w, n,
c, out_h, out_w);
} else if ("nearest" == interp_method) {
NearestNeighborInterpolate<T>(*input, output, ratio_h, ratio_w, n, c,
out_h, out_w);
}
}
};
template <typename T>
class NearestNeighborInterpGradKernel : public framework::OpKernel<T> {
class InterpolateGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
std::string interp_method = ctx.Attr<std::string>("interp_method");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
auto out_size = ctx.Input<Tensor>("OutSize");
......@@ -112,18 +222,12 @@ class NearestNeighborInterpGradKernel : public framework::OpKernel<T> {
float ratio_w =
(out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 4>::From(*output_grad);
for (int k = 0; k < out_h; k++) { // loop for images
int in_k = static_cast<int>(round(ratio_h * k));
for (int l = 0; l < out_w; l++) {
int in_l = static_cast<int>(round(ratio_w * l));
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
input_grad_t(i, j, in_k, in_l) += output_grad_t(i, j, k, l);
}
}
}
if ("bilinear" == interp_method) {
BilinearInterpolationGrad<T>(*output_grad, input_grad, ratio_h, ratio_w,
in_h, in_w, n, c, out_h, out_w);
} else if ("nearest" == interp_method) {
NearestNeighborInterpolateGrad<T>(*output_grad, input_grad, ratio_h,
ratio_w, n, c, out_h, out_w);
}
}
};
......
......@@ -5612,17 +5612,14 @@ def image_resize(input,
out = fluid.layers.image_resize(input, out_shape=[12, 12])
"""
resample_methods = {
'BILINEAR': 'bilinear_interp',
'NEAREST': 'nearest_neighbor_interp'
}
resample_methods = {'BILINEAR': 'bilinear', 'NEAREST': 'nearest'}
if resample not in resample_methods:
raise ValueError(
"The 'resample' of image_resize can only be 'BILINEAR' and 'NEAREST' currently."
)
if out_shape is None and scale is None:
raise ValueError("One of out_shape and scale must not be None")
helper = LayerHelper(resample_methods[resample], **locals())
helper = LayerHelper('interpolate', **locals())
dtype = helper.input_dtype()
def _is_list_or_turple_(data):
......@@ -5647,15 +5644,18 @@ def image_resize(input,
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type=resample_methods[resample],
type='interpolate',
inputs=inputs,
outputs={"Out": out},
attrs={"out_h": out_h,
"out_w": out_w})
attrs={
"out_h": out_h,
"out_w": out_w,
"interp_method": resample_methods[resample]
})
return out
@templatedoc(op_type="bilinear_interp")
@templatedoc(op_type="interpolate")
def resize_bilinear(input, out_shape=None, scale=None, name=None):
"""
${comment}
......@@ -5678,7 +5678,7 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None):
return image_resize(input, out_shape, scale, name, 'BILINEAR')
@templatedoc(op_type="bilinear_interp")
@templatedoc(op_type="interpolate")
def resize_nearest(input, out_shape=None, scale=None, name=None):
"""
${comment}
......
......@@ -20,7 +20,31 @@ from op_test import OpTest
import paddle.fluid.core as core
def nearest_neighbor_interp_np(X, out_h, out_w, out_size=None):
"""nearest neighbor interpolation implement in shape [N, C, H, W]"""
if out_size is not None:
out_h = out_size[0]
out_w = out_size[1]
n, c, in_h, in_w = X.shape
ratio_h = ratio_w = 0.0
if out_h > 1:
ratio_h = (in_h - 1.0) / (out_h - 1.0)
if out_w > 1:
ratio_w = (in_w - 1.0) / (out_w - 1.0)
out = np.zeros((n, c, out_h, out_w))
for i in range(out_h):
in_i = int(ratio_h * i + 0.5)
for j in range(out_w):
in_j = int(ratio_w * j + 0.5)
out[:, :, i, j] = X[:, :, in_i, in_j]
return out.astype(X.dtype)
def bilinear_interp_np(input, out_h, out_w, out_size):
"""bilinear interpolation implement in shape [N, C, H, W]"""
if out_size is not None:
out_h = out_size[0]
out_w = out_size[1]
......@@ -53,18 +77,29 @@ def bilinear_interp_np(input, out_h, out_w, out_size):
return out.astype(input.dtype)
class TestBilinearInterpOp(OpTest):
INTERPOLATE_FUNCS = {
'bilinear': bilinear_interp_np,
'nearest': nearest_neighbor_interp_np,
}
class TestInterpolateOp(OpTest):
def setUp(self):
self.out_size = None
self.init_test_case()
self.op_type = "bilinear_interp"
self.op_type = "interpolate"
input_np = np.random.random(self.input_shape).astype("float32")
output_np = bilinear_interp_np(input_np, self.out_h, self.out_w,
self.out_size)
output_np = INTERPOLATE_FUNCS[self.interp_method](
input_np, self.out_h, self.out_w, self.out_size)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
self.attrs = {'out_h': self.out_h, 'out_w': self.out_w}
self.attrs = {
'out_h': self.out_h,
'out_w': self.out_w,
'interp_method': self.interp_method
}
self.outputs = {'Out': output_np}
def test_check_output(self):
......@@ -74,90 +109,181 @@ class TestBilinearInterpOp(OpTest):
self.check_grad(['X'], 'Out', in_place=True)
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 4, 4]
self.out_h = 2
self.out_w = 2
self.out_size = np.array([3, 3]).astype("int32")
class TestCase1(TestBilinearInterpOp):
class TestBilinearInterpCase1(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
class TestCase2(TestBilinearInterpOp):
class TestBilinearInterpCase2(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
class TestCase3(TestBilinearInterpOp):
class TestBilinearInterpCase3(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [1, 1, 128, 64]
self.out_h = 64
self.out_w = 128
class TestCase4(TestBilinearInterpOp):
class TestBilinearInterpCase4(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
self.out_size = np.array([2, 2]).astype("int32")
class TestCase5(TestBilinearInterpOp):
class TestBilinearInterpCase5(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
self.out_size = np.array([11, 11]).astype("int32")
class TestCase6(TestBilinearInterpOp):
class TestBilinearInterpCase6(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [1, 1, 128, 64]
self.out_h = 64
self.out_w = 128
self.out_size = np.array([65, 129]).astype("int32")
class TestBilinearInterpOpUint8(OpTest):
# class TestBilinearInterpBigScale(TestInterpolateOp):
# def init_test_case(self):
# self.interp_method = 'bilinear'
# self.input_shape = [32, 16, 128, 64]
# self.out_h = 200
# self.out_w = 100
# self.out_size = np.array([201, 101]).astype('int32')
class TestInterpolateOpUint8(OpTest):
def setUp(self):
self.out_size = None
self.init_test_case()
self.op_type = "bilinear_interp"
self.op_type = "interpolate"
input_np = np.random.randint(
low=0, high=256, size=self.input_shape).astype("uint8")
output_np = bilinear_interp_np(input_np, self.out_h, self.out_w,
self.out_size)
output_np = INTERPOLATE_FUNCS[self.interp_method](
input_np, self.out_h, self.out_w, self.out_size)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
self.attrs = {'out_h': self.out_h, 'out_w': self.out_w}
self.attrs = {
'out_h': self.out_h,
'out_w': self.out_w,
'interp_method': self.interp_method
}
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output_with_place(place=core.CPUPlace(), atol=1)
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [1, 3, 9, 6]
self.out_h = 10
self.out_w = 9
class TestCase1Uint8(TestBilinearInterpOpUint8):
class TestBilinearInterpCase1Uint8(TestInterpolateOpUint8):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 128, 64]
self.out_h = 120
self.out_w = 50
class TestBilinearInterpCase2Uint8(TestInterpolateOpUint8):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [4, 1, 7, 8]
self.out_h = 5
self.out_w = 13
self.out_size = np.array([6, 15]).astype("int32")
class TestNearestNeighborInterpCase1(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
class TestNearestNeighborInterpCase2(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
class TestNearestNeighborInterpCase3(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [1, 1, 128, 64]
self.out_h = 64
self.out_w = 128
class TestNearestNeighborInterpCase4(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
self.out_size = np.array([2, 2]).astype("int32")
class TestNearestNeighborInterpCase5(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
self.out_size = np.array([11, 11]).astype("int32")
class TestNearestNeighborInterpCase6(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [1, 1, 128, 64]
self.out_h = 64
self.out_w = 128
self.out_size = np.array([65, 129]).astype("int32")
class TestNearestNeighborInterpCase1Uint8(TestInterpolateOpUint8):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [2, 3, 128, 64]
self.out_h = 120
self.out_w = 50
class TestCase2Uint8(TestBilinearInterpOpUint8):
class TestNearestNeighborInterpCase2Uint8(TestInterpolateOpUint8):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [4, 1, 7, 8]
self.out_h = 5
self.out_w = 13
......
......@@ -485,7 +485,7 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(output)
print(str(program))
def test_resize_bilinear(self):
def test_resize_nearest(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[3, 9, 6], dtype="float32")
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
def nearest_neighbor_interp_np(X, out_h, out_w, out_size=None):
"""nearest neighbor interpolation implement in shape [N, C, H, W]"""
if out_size is not None:
out_h = out_size[0]
out_w = out_size[1]
n, c, in_h, in_w = X.shape
ratio_h = ratio_w = 0.0
if out_h > 1:
ratio_h = (in_h - 1.0) / (out_h - 1.0)
if out_w > 1:
ratio_w = (in_w - 1.0) / (out_w - 1.0)
out = np.zeros((n, c, out_h, out_w))
for i in range(out_h):
in_i = int(round(ratio_h * i))
for j in range(out_w):
in_j = int(round(ratio_w * j))
out[:, :, i, j] = X[:, :, in_i, in_j]
return out.astype(X.dtype)
class TestBilinearInterpOp(OpTest):
def setUp(self):
self.out_size = None
self.init_test_case()
self.op_type = "nearest_neighbor_interp"
input_np = np.random.random(self.input_shape).astype("float32")
output_np = nearest_neighbor_interp_np(input_np, self.out_h, self.out_w,
self.out_size)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
self.attrs = {'out_h': self.out_h, 'out_w': self.out_w}
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', in_place=True)
def init_test_case(self):
self.input_shape = [2, 3, 4, 4]
self.out_h = 2
self.out_w = 2
self.out_size = np.array([3, 3]).astype("int32")
class TestCase1(TestBilinearInterpOp):
def init_test_case(self):
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
class TestCase2(TestBilinearInterpOp):
def init_test_case(self):
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
class TestCase3(TestBilinearInterpOp):
def init_test_case(self):
self.input_shape = [1, 1, 128, 64]
self.out_h = 64
self.out_w = 128
class TestCase4(TestBilinearInterpOp):
def init_test_case(self):
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
self.out_size = np.array([2, 2]).astype("int32")
class TestCase5(TestBilinearInterpOp):
def init_test_case(self):
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
self.out_size = np.array([11, 11]).astype("int32")
class TestCase6(TestBilinearInterpOp):
def init_test_case(self):
self.input_shape = [1, 1, 128, 64]
self.out_h = 64
self.out_w = 128
self.out_size = np.array([65, 129]).astype("int32")
class TestBilinearInterpOpUint8(OpTest):
def setUp(self):
self.out_size = None
self.init_test_case()
self.op_type = "nearest_neighbor_interp"
input_np = np.random.randint(
low=0, high=256, size=self.input_shape).astype("uint8")
output_np = nearest_neighbor_interp_np(input_np, self.out_h, self.out_w,
self.out_size)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
self.attrs = {'out_h': self.out_h, 'out_w': self.out_w}
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output_with_place(place=core.CPUPlace(), atol=1)
def init_test_case(self):
self.input_shape = [1, 3, 9, 6]
self.out_h = 10
self.out_w = 9
class TestCase1Uint8(TestBilinearInterpOpUint8):
def init_test_case(self):
self.input_shape = [2, 3, 128, 64]
self.out_h = 120
self.out_w = 50
class TestCase2Uint8(TestBilinearInterpOpUint8):
def init_test_case(self):
self.input_shape = [4, 1, 7, 8]
self.out_h = 5
self.out_w = 13
self.out_size = np.array([6, 15]).astype("int32")
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册