提交 df4a3544 编写于 作者: D dengkaipeng

nearest neighbor interp add cuda kernel. test=develop

上级 97556119
...@@ -121,6 +121,7 @@ paddle.fluid.layers.dice_loss ArgSpec(args=['input', 'label', 'epsilon'], vararg ...@@ -121,6 +121,7 @@ paddle.fluid.layers.dice_loss ArgSpec(args=['input', 'label', 'epsilon'], vararg
paddle.fluid.layers.image_resize ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR')) paddle.fluid.layers.image_resize ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR'))
paddle.fluid.layers.image_resize_short ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',)) paddle.fluid.layers.image_resize_short ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',))
paddle.fluid.layers.resize_bilinear ArgSpec(args=['input', 'out_shape', 'scale', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.resize_bilinear ArgSpec(args=['input', 'out_shape', 'scale', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.resize_nearest ArgSpec(args=['input', 'out_shape', 'scale', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.gather ArgSpec(args=['input', 'index'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.gather ArgSpec(args=['input', 'index'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.scatter ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.scatter ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.sequence_scatter ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.sequence_scatter ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,))
......
...@@ -25,9 +25,9 @@ class NearestNeighborInterpOp : public framework::OperatorWithKernel { ...@@ -25,9 +25,9 @@ class NearestNeighborInterpOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of BilinearInterOp should not be null."); "Input(X) of NearestNeighborInterOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of BilinearInterOp should not be null."); "Output(Out) of NearestNeighborInterOp should not be null.");
auto dim_x = ctx->GetInputDim("X"); // NCHW format auto dim_x = ctx->GetInputDim("X"); // NCHW format
int out_h = ctx->Attrs().Get<int>("out_h"); int out_h = ctx->Attrs().Get<int>("out_h");
...@@ -64,8 +64,9 @@ class NearestNeighborInterpOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -64,8 +64,9 @@ class NearestNeighborInterpOpMaker : public framework::OpProtoAndCheckerMaker {
.AsDispensable(); .AsDispensable();
AddOutput("Out", "The dimension of output is (N x C x out_h x out_w)"); AddOutput("Out", "The dimension of output is (N x C x out_h x out_w)");
AddAttr<int>("out_h", "output height of bilinear interpolation op."); AddAttr<int>("out_h",
AddAttr<int>("out_w", "output width of bilinear interpolation op."); "output height of nearest neighbor interpolation op.");
AddAttr<int>("out_w", "output width of nearest neighbor interpolation op.");
AddComment(R"DOC( AddComment(R"DOC(
Nearest neighbor interpolation is to perform nearest neighbor interpolation Nearest neighbor interpolation is to perform nearest neighbor interpolation
in bot the 3rd dimention(in height direction) and the 4th dimention(in width in bot the 3rd dimention(in height direction) and the 4th dimention(in width
......
...@@ -15,17 +15,14 @@ ...@@ -15,17 +15,14 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using framework::Tensor; using framework::Tensor;
template <typename T> template <typename T>
__global__ void KeBilinearInterpFw( __global__ void KeNearestNeighborInterpFw(
const T* in, const size_t in_img_h, const size_t in_img_w, const T* in, const size_t in_img_h, const size_t in_img_w,
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h, const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w, const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const T ratio_h, const T ratioW) { const size_t num_channels, const T ratio_h, const T ratio_w) {
int nthreads = output_h * output_w; int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < nthreads) { if (tid < nthreads) {
...@@ -36,34 +33,22 @@ __global__ void KeBilinearInterpFw( ...@@ -36,34 +33,22 @@ __global__ void KeBilinearInterpFw(
int channel_id = out_id_w / out_img_size; int channel_id = out_id_w / out_img_size;
int out_img_idy = (out_id_w % out_img_size) / out_img_w; int out_img_idy = (out_id_w % out_img_size) / out_img_w;
int in_img_idy = ratio_h * out_img_idy; int in_img_idy = static_cast<int>(round(ratio_h * out_img_idy));
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
T h1lambda = ratio_h * out_img_idy - in_img_idy;
T h2lambda = 1.f - h1lambda;
int out_img_idx = tid % out_img_w; int out_img_idx = tid % out_img_w;
int in_img_idx = ratioW * out_img_idx; int in_img_idx = static_cast<int>(round(ratio_w * out_img_idx));
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
T w1lambda = ratioW * out_img_idx - in_img_idx; out[tid] = in[out_id_h * input_w + channel_id * in_img_size +
T w2lambda = 1.f - w1lambda; in_img_idy * in_img_w + in_img_idx];
const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
in_img_idy * in_img_w + in_img_idx];
// bilinear interpolation
out[out_id_h * output_w + out_id_w] =
h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) +
h1lambda * (w2lambda * in_pos[h_id * in_img_w] +
w1lambda * in_pos[h_id * in_img_w + w_id]);
} }
} }
template <typename T> template <typename T>
__global__ void KeBilinearInterpBw( __global__ void KeNearestNeighborInterpBw(
T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h, T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
const size_t input_w, const T* out, const size_t out_img_h, const size_t input_w, const T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w, const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const T ratio_h, const T ratioW) { const size_t num_channels, const T ratio_h, const T ratio_w) {
int nthreads = output_h * output_w; int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < nthreads) { if (tid < nthreads) {
...@@ -74,25 +59,15 @@ __global__ void KeBilinearInterpBw( ...@@ -74,25 +59,15 @@ __global__ void KeBilinearInterpBw(
int channel_id = out_id_w / out_img_size; int channel_id = out_id_w / out_img_size;
int out_img_idy = (out_id_w % out_img_size) / out_img_w; int out_img_idy = (out_id_w % out_img_size) / out_img_w;
int in_img_idy = ratio_h * out_img_idy; int in_img_idy = static_cast<int>(round(ratio_h * out_img_idy));
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
T h1lambda = ratio_h * out_img_idy - in_img_idy;
T h2lambda = 1.f - h1lambda;
int out_img_idx = tid % out_img_w; int out_img_idx = tid % out_img_w;
int in_img_idx = ratioW * out_img_idx; int in_img_idx = static_cast<int>(round(ratio_w * out_img_idx));
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
T w1lambda = ratioW * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
in_img_idy * in_img_w + in_img_idx]; in_img_idy * in_img_w + in_img_idx];
const T* out_pos = &out[out_id_h * output_w + out_id_w]; const T out_pos = out[out_id_h * output_w + out_id_w];
atomicAdd(&in_pos[0], h2lambda * w2lambda * out_pos[0]); atomicAdd(in_pos, out_pos);
atomicAdd(&in_pos[w_id], h2lambda * w1lambda * out_pos[0]);
atomicAdd(&in_pos[h_id * in_img_w], h1lambda * w2lambda * out_pos[0]);
atomicAdd(&in_pos[h_id * in_img_w + w_id],
h1lambda * w1lambda * out_pos[0]);
} }
} }
...@@ -102,48 +77,49 @@ class NearestNeighborInterpOpCUDAKernel : public framework::OpKernel<T> { ...@@ -102,48 +77,49 @@ class NearestNeighborInterpOpCUDAKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device."); "This kernel only runs on GPU device.");
auto* input_t = ctx.Input<Tensor>("X"); // float tensor auto* input = ctx.Input<Tensor>("X"); // float tensor
auto* output_t = ctx.Output<Tensor>("Out"); // float tensor auto* output = ctx.Output<Tensor>("Out"); // float tensor
auto* input = input_t->data<T>(); auto* input_data = input->data<T>();
int out_h = ctx.Attr<int>("out_h"); int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w"); int out_w = ctx.Attr<int>("out_w");
auto out_dims = output_t->dims(); auto out_size = ctx.Input<Tensor>("OutSize");
auto out_size_t = ctx.Input<Tensor>("OutSize"); if (out_size != nullptr) {
if (out_size_t != nullptr) {
Tensor sizes; Tensor sizes;
framework::TensorCopy(*out_size_t, platform::CPUPlace(), &sizes); framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>(); auto size_data = sizes.data<int>();
out_h = size_data[0]; out_h = size_data[0];
out_w = size_data[1]; out_w = size_data[1];
} }
auto* output = output_t->mutable_data<T>(
{out_dims[0], out_dims[1], out_h, out_w}, ctx.GetPlace());
int batch_size = input_t->dims()[0]; int n = input->dims()[0];
int channels = input_t->dims()[1]; int c = input->dims()[1];
int in_h = input_t->dims()[2]; int in_h = input->dims()[2];
int in_w = input_t->dims()[3]; int in_w = input->dims()[3];
auto* output_data =
output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
int in_hw = in_h * in_w; int in_hw = in_h * in_w;
int out_hw = out_h * out_w; int out_hw = out_h * out_w;
int in_chw = channels * in_hw; int in_chw = c * in_hw;
int out_chw = channels * out_hw; int out_chw = c * out_hw;
T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f; T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f; T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;
if (in_h == out_h && in_w == out_w) { if (in_h == out_h && in_w == out_w) {
memcpy(output, input, input_t->numel() * sizeof(T)); memcpy(output_data, input_data, input->numel() * sizeof(T));
} else { return;
int threadNum = batch_size * out_chw;
int blocks = (threadNum + 1024 - 1) / 1024;
KeBilinearInterpFw<
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
input, in_h, in_w, batch_size, in_chw, output, out_h, out_w,
batch_size, out_chw, channels, ratio_h, ratio_w);
} }
int threadNum = n * out_chw;
int blocks = (threadNum + 1024 - 1) / 1024;
KeNearestNeighborInterpFw<
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w);
} }
}; };
...@@ -151,52 +127,53 @@ template <typename T> ...@@ -151,52 +127,53 @@ template <typename T>
class NearestNeighborInterpGradOpCUDAKernel : public framework::OpKernel<T> { class NearestNeighborInterpGradOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* d_input_t = ctx.Output<Tensor>(framework::GradVarName("X")); auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* d_output_t = ctx.Input<Tensor>(framework::GradVarName("Out")); auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* d_output = d_output_t->data<T>(); auto* output_grad_data = output_grad->data<T>();
auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace()); auto* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
auto& device_ctx = auto& device_ctx =
ctx.template device_context<platform::CUDADeviceContext>(); ctx.template device_context<platform::CUDADeviceContext>();
math::SetConstant<platform::CUDADeviceContext, T> zero; math::SetConstant<platform::CUDADeviceContext, T> zero;
zero(device_ctx, d_input_t, static_cast<T>(0.0)); zero(device_ctx, input_grad, static_cast<T>(0.0));
int out_h = ctx.Attr<int>("out_h"); int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w"); int out_w = ctx.Attr<int>("out_w");
auto out_size_t = ctx.Input<Tensor>("OutSize"); auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size_t != nullptr) { if (out_size != nullptr) {
Tensor sizes; Tensor sizes;
framework::TensorCopy(*out_size_t, platform::CPUPlace(), &sizes); framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>(); auto size_data = sizes.data<int>();
out_h = size_data[0]; out_h = size_data[0];
out_w = size_data[1]; out_w = size_data[1];
} }
int batch_size = d_input_t->dims()[0]; int n = input_grad->dims()[0];
int channels = d_input_t->dims()[1]; int c = input_grad->dims()[1];
int in_h = d_input_t->dims()[2]; int in_h = input_grad->dims()[2];
int in_w = d_input_t->dims()[3]; int in_w = input_grad->dims()[3];
int in_hw = in_h * in_w; int in_hw = in_h * in_w;
int out_hw = out_h * out_w; int out_hw = out_h * out_w;
int in_chw = channels * in_hw; int in_chw = c * in_hw;
int out_chw = channels * out_hw; int out_chw = c * out_hw;
T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f; T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f; T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;
if (in_h == out_h && in_w == out_w) { if (in_h == out_h && in_w == out_w) {
memcpy(d_input, d_output, d_input_t->numel() * sizeof(T)); memcpy(input_grad, output_grad, input_grad->numel() * sizeof(T));
} else { return;
int threadNum = batch_size * out_chw;
int blocks = (threadNum + 1024 - 1) / 1024;
KeBilinearInterpBw<
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
d_input, in_h, in_w, batch_size, in_chw, d_output, out_h, out_w,
batch_size, out_chw, channels, ratio_h, ratio_w);
} }
int threadNum = n * out_chw;
int blocks = (threadNum + 1024 - 1) / 1024;
KeNearestNeighborInterpBw<
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
n, out_chw, c, ratio_h, ratio_w);
} }
}; };
...@@ -206,5 +183,5 @@ class NearestNeighborInterpGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -206,5 +183,5 @@ class NearestNeighborInterpGradOpCUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(nearest_neighbor_interp, REGISTER_OP_CUDA_KERNEL(nearest_neighbor_interp,
ops::NearestNeighborInterpOpCUDAKernel<float>); ops::NearestNeighborInterpOpCUDAKernel<float>);
REGISTER_OP_CUDA_KERNEL(nearest_neighborinterp_grad, REGISTER_OP_CUDA_KERNEL(nearest_neighbor_interp_grad,
ops::NearestNeighborInterpGradOpCUDAKernel<float>); ops::NearestNeighborInterpGradOpCUDAKernel<float>);
...@@ -101,6 +101,7 @@ __all__ = [ ...@@ -101,6 +101,7 @@ __all__ = [
'image_resize', 'image_resize',
'image_resize_short', 'image_resize_short',
'resize_bilinear', 'resize_bilinear',
'resize_nearest',
'gather', 'gather',
'scatter', 'scatter',
'sequence_scatter', 'sequence_scatter',
...@@ -5584,6 +5585,7 @@ def image_resize(input, ...@@ -5584,6 +5585,7 @@ def image_resize(input,
Supporting resample methods: Supporting resample methods:
'BILINEAR' : Bilinear interpolation 'BILINEAR' : Bilinear interpolation
'NEAREST' : Nearest neighbor interpolation
Args: Args:
input (Variable): The input tensor of image resize layer, input (Variable): The input tensor of image resize layer,
...@@ -5610,13 +5612,17 @@ def image_resize(input, ...@@ -5610,13 +5612,17 @@ def image_resize(input,
out = fluid.layers.image_resize(input, out_shape=[12, 12]) out = fluid.layers.image_resize(input, out_shape=[12, 12])
""" """
resample_methods = {'BILINEAR': 'bilinear_interp'} resample_methods = {
'BILINEAR': 'bilinear_interp',
'NEAREST': 'nearest_neighbor_interp'
}
if resample not in resample_methods: if resample not in resample_methods:
raise ValueError( raise ValueError(
"The 'resample' of image_resize can only be 'BILINEAR' currently.") "The 'resample' of image_resize can only be 'BILINEAR' and 'NEAREST' currently."
)
if out_shape is None and scale is None: if out_shape is None and scale is None:
raise ValueError("One of out_shape and scale must not be None") raise ValueError("One of out_shape and scale must not be None")
helper = LayerHelper('bilinear_interp', **locals()) helper = LayerHelper(resample_methods[resample], **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
def _is_list_or_turple_(data): def _is_list_or_turple_(data):
...@@ -5672,6 +5678,29 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None): ...@@ -5672,6 +5678,29 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None):
return image_resize(input, out_shape, scale, name, 'BILINEAR') return image_resize(input, out_shape, scale, name, 'BILINEAR')
@templatedoc(op_type="bilinear_interp")
def resize_nearest(input, out_shape=None, scale=None, name=None):
"""
${comment}
Args:
input(${x_type}): ${x_comment}.
out_shape(${out_size_type}): ${out_size_comment}.
scale(float|None): The multiplier for the input height or width. At
least one of out_shape or scale must be set. And out_shape has
a higher priority than scale. Default: None.
name(str|None): The output variable name.
Returns:
${out_comment}.
"""
return image_resize(input, out_shape, scale, name, 'NEAREST')
def image_resize_short(input, out_short_len, resample='BILINEAR'): def image_resize_short(input, out_short_len, resample='BILINEAR'):
""" """
Resize a batch of images. The short edge of input images will be Resize a batch of images. The short edge of input images will be
......
...@@ -485,6 +485,16 @@ class TestBook(unittest.TestCase): ...@@ -485,6 +485,16 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(output) self.assertIsNotNone(output)
print(str(program)) print(str(program))
def test_resize_bilinear(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[3, 9, 6], dtype="float32")
output = layers.resize_nearest(x, out_shape=[12, 12])
self.assertIsNotNone(output)
output = layers.resize_nearest(x, scale=3)
self.assertIsNotNone(output)
print(str(program))
def test_polygon_box_transform(self): def test_polygon_box_transform(self):
program = Program() program = Program()
with program_guard(program): with program_guard(program):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册