From e10aa80f03e0046e25fa8faaea9d5af6a277e346 Mon Sep 17 00:00:00 2001 From: whs Date: Fri, 31 Aug 2018 17:03:48 +0800 Subject: [PATCH] Add pad2d op. (#12950) * Add pad2d op. * Add unitest and python api. * Fix cuda op kernel. * Fix python api. * Fix python api. * Update API.spec. * Fix python api --- paddle/fluid/API.spec | 1 + paddle/fluid/operators/pad2d_op.cc | 584 ++++++++++++++++++ paddle/fluid/operators/pad2d_op.cu | 432 +++++++++++++ python/paddle/fluid/layers/nn.py | 93 ++- .../fluid/tests/unittests/test_layers.py | 14 + .../fluid/tests/unittests/test_pad2d_op.py | 102 +++ 6 files changed, 1224 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/operators/pad2d_op.cc create mode 100644 paddle/fluid/operators/pad2d_op.cu create mode 100644 python/paddle/fluid/tests/unittests/test_pad2d_op.py diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index ed4e67879c7..b6ae930b715 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -170,6 +170,7 @@ paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], vara paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None)) paddle.fluid.layers.sequence_mask ArgSpec(args=['x', 'maxlen', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 'int64', None)) paddle.fluid.layers.stack ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,)) +paddle.fluid.layers.pad2d ArgSpec(args=['input', 'paddings', 'mode', 'pad_value', 'data_format', 'name'], varargs=None, keywords=None, defaults=([0, 0, 0, 0], 'constant', 0.0, 'NCHW', None)) paddle.fluid.layers.unstack ArgSpec(args=['x', 'axis', 'num'], varargs=None, keywords=None, defaults=(0, None)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True)) diff --git a/paddle/fluid/operators/pad2d_op.cc b/paddle/fluid/operators/pad2d_op.cc new file mode 100644 index 00000000000..a706d05fd7c --- /dev/null +++ b/paddle/fluid/operators/pad2d_op.cc @@ -0,0 +1,584 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +template +void Pad2DConstNCHW(const T* in_data, const int num, const int channels, + const int in_height, const int in_width, + const int out_height, const int out_width, + const int pad_top, const int pad_left, T value, + T* out_data) { + for (int n = 0; n < num; ++n) { + for (int c = 0; c < channels; ++c) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + out_data[out_h * out_width + out_w] = + (in_h < 0 || in_w < 0 || in_h >= in_height || in_w >= in_width) + ? value + : in_data[in_h * in_width + in_w]; + } + } + in_data += in_height * in_width; + out_data += out_height * out_width; + } + } +} + +template +void Pad2DConstNHWC(const T* in_data, const int num, const int channels, + const int in_height, const int in_width, + const int out_height, const int out_width, + const int pad_top, const int pad_left, T value, + T* out_data) { + for (int n = 0; n < num; ++n) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + const int out_index = (out_h * out_width + out_w) * channels; + if (in_h < 0 || in_w < 0 || in_h >= in_height || in_w >= in_width) { + for (int c = 0; c < channels; ++c) { + out_data[out_index + c] = value; + } + } else { + const int in_index = (in_h * in_width + in_w) * channels; + for (int c = 0; c < channels; ++c) { + out_data[out_index + c] = in_data[in_index + c]; + } + } + } + } + in_data += in_height * in_width * channels; + out_data += out_height * out_width * channels; + } +} + +template +void Pad2DReflectNCHW(const T* in_data, const int num, const int channels, + const int in_height, const int in_width, + const int out_height, const int out_width, + const int pad_top, const int pad_left, T* out_data) { + for (int n = 0; n < num; ++n) { + for (int c = 0; c < channels; ++c) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + in_h = std::max(in_h, -in_h); // reflect by 0 + in_h = + std::min(in_h, 2 * in_height - in_h - 2); // reflect by in_height + in_w = std::max(in_w, -in_w); // reflect by 0 + in_w = + std::min(in_w, 2 * in_width - in_w - 2); // reflect by in_width + out_data[out_h * out_width + out_w] = in_data[in_h * in_width + in_w]; + } + } + in_data += in_height * in_width; + out_data += out_height * out_width; + } + } +} + +template +void Pad2DReflectNHWC(const T* in_data, const int num, const int channels, + const int in_height, const int in_width, + const int out_height, const int out_width, + const int pad_top, const int pad_left, T* out_data) { + for (int n = 0; n < num; ++n) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + const int out_index = (out_h * out_width + out_w) * channels; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + in_h = std::max(in_h, -in_h); + in_h = std::min(in_h, 2 * in_height - in_h - 2); + in_w = std::max(in_w, -in_w); + in_w = std::min(in_w, 2 * in_width - in_w - 2); + const int in_index = (in_h * in_width + in_w) * channels; + + for (int c = 0; c < channels; ++c) { + out_data[out_index + c] = in_data[in_index + c]; + } + } + } + in_data += in_height * in_width * channels; + out_data += out_height * out_width * channels; + } +} + +template +void Pad2DEdgeNCHW(const T* in_data, const int num, const int channels, + const int in_height, const int in_width, + const int out_height, const int out_width, const int pad_top, + const int pad_left, T* out_data) { + for (int n = 0; n < num; ++n) { + for (int c = 0; c < channels; ++c) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 0)); + int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0)); + out_data[out_h * out_width + out_w] = in_data[in_h * in_width + in_w]; + } + } + in_data += in_height * in_width; + out_data += out_height * out_width; + } + } +} + +template +void Pad2DEdgeNHWC(const T* in_data, const int num, const int channels, + const int in_height, const int in_width, + const int out_height, const int out_width, const int pad_top, + const int pad_left, T* out_data) { + for (int n = 0; n < num; ++n) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + const int out_index = (out_h * out_width + out_w) * channels; + int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 0)); + int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0)); + const int in_index = (in_h * in_width + in_w) * channels; + for (int c = 0; c < channels; ++c) { + out_data[out_index + c] = in_data[in_index + c]; + } + } + } + in_data += in_height * in_width * channels; + out_data += out_height * out_width * channels; + } +} + +template +void Pad2DGradConstNCHW(T* d_in_data, const int num, const int channels, + const int in_height, const int in_width, + const int out_height, const int out_width, + const int pad_top, const int pad_left, + const T* d_out_data) { + for (int n = 0; n < num; ++n) { + for (int c = 0; c < channels; ++c) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + if (!(in_h < 0 || in_w < 0 || in_h >= in_height || + in_w >= in_width)) { + d_in_data[in_h * in_width + in_w] = + d_out_data[out_h * out_width + out_w]; + } + } + } + d_in_data += in_height * in_width; + d_out_data += out_height * out_width; + } + } +} + +template +void Pad2DGradConstNHWC(T* d_in_data, const int num, const int channels, + const int in_height, const int in_width, + const int out_height, const int out_width, + const int pad_top, const int pad_left, + const T* d_out_data) { + for (int n = 0; n < num; ++n) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + const int out_index = (out_h * out_width + out_w) * channels; + if (!(in_h < 0 || in_w < 0 || in_h >= in_height || in_w >= in_width)) { + const int in_index = (in_h * in_width + in_w) * channels; + for (int c = 0; c < channels; ++c) { + d_in_data[in_index + c] = d_out_data[out_index + c]; + } + } + } + } + d_in_data += in_height * in_width * channels; + d_out_data += out_height * out_width * channels; + } +} + +template +void Pad2DGradReflectNCHW(T* d_in_data, const int num, const int channels, + const int in_height, const int in_width, + const int out_height, const int out_width, + const int pad_top, const int pad_left, + const T* d_out_data) { + for (int n = 0; n < num; ++n) { + for (int c = 0; c < channels; ++c) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + in_h = std::max(in_h, -in_h); // reflect over 0 + in_h = std::min(in_h, + 2 * in_height - in_h - 2); // reflect over in_height + in_w = std::max(in_w, -in_w); // reflect over 0 + in_w = + std::min(in_w, 2 * in_width - in_w - 2); // reflect over in_width + d_in_data[in_h * in_width + in_w] += + d_out_data[out_h * out_width + out_w]; + } + } + d_in_data += in_height * in_width; + d_out_data += out_height * out_width; + } + } +} + +template +void Pad2DGradReflectNHWC(T* d_in_data, const int num, const int channels, + const int in_height, const int in_width, + const int out_height, const int out_width, + const int pad_top, const int pad_left, + const T* d_out_data) { + for (int n = 0; n < num; ++n) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + const int out_index = (out_h * out_width + out_w) * channels; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + in_h = std::max(in_h, -in_h); + in_h = std::min(in_h, 2 * in_height - in_h - 2); + in_w = std::max(in_w, -in_w); + in_w = std::min(in_w, 2 * in_width - in_w - 2); + const int in_index = (in_h * in_width + in_w) * channels; + for (int c = 0; c < channels; ++c) { + d_in_data[in_index + c] += d_out_data[out_index + c]; + } + } + } + d_in_data += in_height * in_width * channels; + d_out_data += out_height * out_width * channels; + } +} + +template +void Pad2DGradEdgeNCHW(T* d_in_data, const int num, const int channels, + const int in_height, const int in_width, + const int out_height, const int out_width, + const int pad_top, const int pad_left, + const T* d_out_data) { + for (int n = 0; n < num; ++n) { + for (int c = 0; c < channels; ++c) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 0)); + int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0)); + d_in_data[in_h * in_width + in_w] += + d_out_data[out_h * out_width + out_w]; + } + } + d_in_data += in_height * in_width; + d_out_data += out_height * out_width; + } + } +} + +template +void Pad2DGradEdgeNHWC(T* d_in_data, const int num, const int channels, + const int in_height, const int in_width, + const int out_height, const int out_width, + const int pad_top, const int pad_left, + const T* d_out_data) { + for (int n = 0; n < num; ++n) { + for (int out_h = 0; out_h < out_height; ++out_h) { + for (int out_w = 0; out_w < out_width; ++out_w) { + const int out_index = (out_h * out_width + out_w) * channels; + int in_h = std::min(in_height - 1, std::max(out_h - pad_top, 0)); + int in_w = std::min(in_width - 1, std::max(out_w - pad_left, 0)); + const int in_index = (in_h * in_width + in_w) * channels; + for (int c = 0; c < channels; ++c) { + d_in_data[in_index + c] += d_out_data[out_index + c]; + } + } + } + d_in_data += in_height * in_width * channels; + d_out_data += out_height * out_width * channels; + } +} + +template +class Pad2dCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto pads = context.Attr>("paddings"); + auto mode = context.Attr("mode"); + auto data_format = context.Attr("data_format"); + T value = context.Attr("pad_value"); + auto* x = context.Input("X"); + auto* out = context.Output("Out"); + auto in_dims = x->dims(); + auto out_dims = out->dims(); + const T* in_data = x->data(); + T* out_data = out->mutable_data(context.GetPlace()); + const int pad_top = pads[0]; + const int pad_left = pads[2]; + const int num = in_dims[0]; + if (data_format == "NCHW") { + const int channels = in_dims[1]; + const int in_height = in_dims[2]; + const int in_width = in_dims[3]; + const int out_height = out_dims[2]; + const int out_width = out_dims[3]; + if (mode == "reflect") { + Pad2DReflectNCHW(in_data, num, channels, in_height, in_width, + out_height, out_width, pad_top, pad_left, out_data); + } else if (mode == "edge") { + Pad2DEdgeNCHW(in_data, num, channels, in_height, in_width, out_height, + out_width, pad_top, pad_left, out_data); + } else { + Pad2DConstNCHW(in_data, num, channels, in_height, in_width, out_height, + out_width, pad_top, pad_left, value, out_data); + } + } else { + const int channels = in_dims[3]; + const int in_height = in_dims[1]; + const int in_width = in_dims[2]; + const int out_height = out_dims[1]; + const int out_width = out_dims[2]; + if (mode == "reflect") { + Pad2DReflectNHWC(in_data, num, channels, in_height, in_width, + out_height, out_width, pad_top, pad_left, out_data); + } else if (mode == "edge") { + Pad2DEdgeNHWC(in_data, num, channels, in_height, in_width, out_height, + out_width, pad_top, pad_left, out_data); + } else { + Pad2DConstNHWC(in_data, num, channels, in_height, in_width, out_height, + out_width, pad_top, pad_left, value, out_data); + } + } + } +}; + +template +class Pad2dGradCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto pads = context.Attr>("paddings"); + auto mode = context.Attr("mode"); + auto data_format = context.Attr("data_format"); + auto* d_out = context.Input(framework::GradVarName("Out")); + auto* d_in = context.Output(framework::GradVarName("X")); + auto d_in_dims = d_in->dims(); + auto d_out_dims = d_out->dims(); + const T* d_out_data = d_out->data(); + T* d_in_data = d_in->mutable_data(context.GetPlace()); + math::SetConstant set_zero; + set_zero(context.template device_context(), + d_in, static_cast(0)); + const int pad_top = pads[0]; + const int pad_left = pads[2]; + const int num = d_in_dims[0]; + if (data_format == "NCHW") { + const int channels = d_in_dims[1]; + const int in_height = d_in_dims[2]; + const int in_width = d_in_dims[3]; + const int out_height = d_out_dims[2]; + const int out_width = d_out_dims[3]; + if (mode == "reflect") { + Pad2DGradReflectNCHW(d_in_data, num, channels, in_height, in_width, + out_height, out_width, pad_top, pad_left, + d_out_data); + } else if (mode == "edge") { + Pad2DGradEdgeNCHW(d_in_data, num, channels, in_height, in_width, + out_height, out_width, pad_top, pad_left, d_out_data); + } else { + Pad2DGradConstNCHW(d_in_data, num, channels, in_height, in_width, + out_height, out_width, pad_top, pad_left, + d_out_data); + } + } else { + const int channels = d_in_dims[3]; + const int in_height = d_in_dims[1]; + const int in_width = d_in_dims[2]; + const int out_height = d_out_dims[1]; + const int out_width = d_out_dims[2]; + if (mode == "reflect") { + Pad2DGradReflectNHWC(d_in_data, num, channels, in_height, in_width, + out_height, out_width, pad_top, pad_left, + d_out_data); + } else if (mode == "edge") { + Pad2DGradEdgeNHWC(d_in_data, num, channels, in_height, in_width, + out_height, out_width, pad_top, pad_left, d_out_data); + } else { + Pad2DGradConstNHWC(d_in_data, num, channels, in_height, in_width, + out_height, out_width, pad_top, pad_left, + d_out_data); + } + } + } +}; + +class Pad2dOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of Pad2dOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of Pad2dOp should not be null."); + + auto x_dim = ctx->GetInputDim("X"); + auto paddings = ctx->Attrs().Get>("paddings"); + PADDLE_ENFORCE_EQ(x_dim.size(), 4, + "Size of paddings should be equal to 4."); + std::vector out_dims(x_dim.size()); + + auto data_format = ctx->Attrs().Get("data_format"); + out_dims[0] = x_dim[0]; + if (data_format == "NCHW") { + out_dims[1] = x_dim[1]; + out_dims[2] = x_dim[2] + paddings[0] + paddings[1]; // height + out_dims[3] = x_dim[3] + paddings[2] + paddings[3]; // width + } else { // NHWC + out_dims[3] = x_dim[3]; + out_dims[1] = x_dim[1] + paddings[0] + paddings[1]; + out_dims[2] = x_dim[2] + paddings[2] + paddings[3]; + } + + ctx->SetOutputDim("Out", framework::make_ddim(out_dims)); + if (out_dims[0] == x_dim[0]) { + // Only pass LoD when the first dimension is equal between + // output and input. + ctx->ShareLoD("X", /*->*/ "Out"); + } + } +}; + +class Pad2dOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "The input of pad2d op. " + "The input should be a 4-D tensor with formate NCHW or NHWC."); + AddOutput("Out", + "The output of pad2d op. " + "A tensor with the same shape as X."); + AddAttr>( + "paddings", + "(vector) " + "A list to describe the padding rules." + "paddings=[0, 1, 2, 3] means " + "padding 0 row to top, 1 row to bottom, 2 columns to left " + "and 3 columns to right. Size of paddings must be 4."); + AddAttr("pad_value", + "(float, default 0.0) " + "The value to fill the padded areas in constant mode.") + .SetDefault(0.0f); + AddAttr("mode", + "(float, default constant) " + "Three modes: constant(default), reflect, edge.") + .SetDefault("constant"); + AddAttr( + "data_format", + "(string, default NCHW) Only used in " + "An optional string from: \"NHWC\", \"NCHW\". " + "Defaults to \"NHWC\". Specify the data format of the input data.") + .SetDefault("NCHW"); + AddComment(R"DOC( +Pad2d Operator. +Pad 2-d images accordding to 'paddings' and 'mode'. +If mode is 'reflect', paddings[0] and paddings[1] must be no greater +than height-1. And the width dimension has the same condition. + +Given that X is a channel of image from input: + +X = [[1, 2, 3], + [4, 5, 6]] + +Case 0: + +paddings = [0, 1, 2, 3], +mode = 'constant' +pad_value = 0 + +Out = [[0, 0, 1, 2, 3, 0, 0, 0] + [0, 0, 4, 5, 6, 0, 0, 0] + [0, 0, 0, 0, 0, 0, 0, 0]] + +Case 1: + +paddings = [0, 1, 2, 1], +mode = 'reflect' + +Out = [[3, 2, 1, 2, 3, 2] + [6, 5, 4, 5, 6, 5] + [3, 2, 1, 2, 3, 2]] + +Case 2: + +paddings = [0, 1, 2, 1], +mode = 'edge' + +Out = [[1, 1, 1, 2, 3, 3] + [4, 4, 4, 5, 6, 6] + [4, 4, 4, 5, 6, 6]] +)DOC"); + } +}; + +class Pad2dOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto x_dims = ctx->GetInputDim("X"); + auto x_grad_name = framework::GradVarName("X"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + } +}; + +class Pad2dOpGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto* bind = new framework::OpDesc(); + bind->SetInput("X", Input("X")); + bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); + bind->SetOutput(framework::GradVarName("X"), InputGrad("X")); + bind->SetAttrMap(Attrs()); + bind->SetType("pad2d_grad"); + return std::unique_ptr(bind); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(pad2d, ops::Pad2dOp, ops::Pad2dOpMaker, + ops::Pad2dOpGradMaker); +REGISTER_OPERATOR(pad2d_grad, ops::Pad2dOpGrad); +REGISTER_OP_CPU_KERNEL(pad2d, ops::Pad2dCPUKernel); +REGISTER_OP_CPU_KERNEL(pad2d_grad, ops::Pad2dGradCPUKernel); diff --git a/paddle/fluid/operators/pad2d_op.cu b/paddle/fluid/operators/pad2d_op.cu new file mode 100644 index 00000000000..9ba0ddbd84a --- /dev/null +++ b/paddle/fluid/operators/pad2d_op.cu @@ -0,0 +1,432 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/gpu_info.h" + +namespace paddle { +namespace operators { + +using platform::PADDLE_CUDA_NUM_THREADS; + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +using framework::Tensor; + +template +__global__ void Pad2DConstNCHW(const int nthreads, const T* in_data, + const int num, const int channels, + const int in_height, const int in_width, + const int out_height, const int out_width, + const int pad_top, const int pad_left, T value, + T* out_data) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + int nc = index / out_width; + const int out_w = index % out_width; + const int out_h = nc % out_height; + nc /= out_height; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + out_data[index] = + (in_h < 0 || in_w < 0 || in_h >= in_height || in_w >= in_width) + ? value + : in_data[(nc * in_height + in_h) * in_width + in_w]; + } +} + +template +__global__ void Pad2DConstNHWC(const int nthreads, const T* in_data, + const int num, const int channels, + const int in_height, const int in_width, + const int out_height, const int out_width, + const int pad_top, const int pad_left, T value, + T* out_data) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + int n = index / channels; + const int c = index % channels; + const int out_w = n % out_width; + n /= out_width; + const int out_h = n % out_height; + n /= out_height; + const int in_h = out_h - pad_top; + const int in_w = out_w - pad_left; + out_data[index] = + (in_h < 0 || in_w < 0 || in_h >= in_height || in_w >= in_width) + ? value + : in_data[((n * in_height + in_h) * in_width + in_w) * channels + + c]; + } +} + +template +__global__ void Pad2DReflectNCHW(const int nthreads, const T* in_data, + const int num, const int channels, + const int in_height, const int in_width, + const int out_height, const int out_width, + const int pad_top, const int pad_left, + T* out_data) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + int nc = index / out_width; + const int out_w = index % out_width; + const int out_h = nc % out_height; + nc /= out_height; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + in_h = max(in_h, -in_h); // reflect by 0 + in_h = min(in_h, 2 * in_height - in_h - 2); // reflect by in_height + in_w = max(in_w, -in_w); // reflect by 0 + in_w = min(in_w, 2 * in_width - in_w - 2); // reflect by in_width + out_data[index] = in_data[(nc * in_height + in_h) * in_width + in_w]; + } +} + +template +__global__ void Pad2DReflectNHWC(const int nthreads, const T* in_data, + const int num, const int channels, + const int in_height, const int in_width, + const int out_height, const int out_width, + const int pad_top, const int pad_left, + T* out_data) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + int n = index / channels; + const int c = index % channels; + const int out_w = n % out_width; + n /= out_width; + const int out_h = n % out_height; + n /= out_height; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + in_h = max(in_h, -in_h); + in_h = min(in_h, 2 * in_height - in_h - 2); + in_w = max(in_w, -in_w); + in_w = min(in_w, 2 * in_width - in_w - 2); + out_data[index] = + in_data[((n * in_height + in_h) * in_width + in_w) * channels + c]; + } +} + +template +__global__ void Pad2DEdgeNCHW(const int nthreads, const T* in_data, + const int num, const int channels, + const int in_height, const int in_width, + const int out_height, const int out_width, + const int pad_top, const int pad_left, + T* out_data) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + int nc = index / out_width; + const int out_w = index % out_width; + const int out_h = nc % out_height; + nc /= out_height; + int in_h = min(in_height - 1, max(out_h - pad_top, 0)); + int in_w = min(in_width - 1, max(out_w - pad_left, 0)); + out_data[index] = in_data[(nc * in_height + in_h) * in_width + in_w]; + } +} + +template +__global__ void Pad2DEdgeNHWC(const int nthreads, const T* in_data, + const int num, const int channels, + const int in_height, const int in_width, + const int out_height, const int out_width, + const int pad_top, const int pad_left, + T* out_data) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + int n = index / channels; + const int c = index % channels; + const int out_w = n % out_width; + n /= out_width; + const int out_h = n % out_height; + n /= out_height; + int in_h = min(in_height - 1, max(out_h - pad_top, 0)); + int in_w = min(in_width - 1, max(out_w - pad_left, 0)); + out_data[index] = + in_data[((n * in_height + in_h) * in_width + in_w) * channels + c]; + } +} + +template +__global__ void Pad2DGradConstNCHW(const int in_size, T* d_in_data, + const int num, const int channels, + const int in_height, const int in_width, + const int out_height, const int out_width, + const int pad_top, const int pad_left, + const T* d_out_data) { + CUDA_1D_KERNEL_LOOP(in_index, in_size) { + int nc = in_index / in_width; + const int out_w = in_index % in_width + pad_left; + const int out_h = nc % in_height + pad_top; + nc /= in_height; + d_in_data[in_index] = + d_out_data[(nc * out_height + out_h) * out_width + out_w]; + } +} + +template +__global__ void Pad2DGradConstNHWC(const int in_size, T* d_in_data, + const int num, const int channels, + const int in_height, const int in_width, + const int out_height, const int out_width, + const int pad_top, const int pad_left, + const T* d_out_data) { + CUDA_1D_KERNEL_LOOP(in_index, in_size) { + int n = in_index / channels; + const int c = in_index % channels; + const int out_w = n % in_width + pad_left; + n /= in_width; + const int out_h = n % in_height + pad_top; + n /= in_height; + d_in_data[in_index] = + d_out_data[((n * out_height + out_h) * out_width + out_w) * channels + + c]; + } +} + +template +__global__ void Pad2DGradReflectNCHW(const int out_size, T* d_in_data, + const int num, const int channels, + const int in_height, const int in_width, + const int out_height, const int out_width, + const int pad_top, const int pad_left, + const T* d_out_data) { + CUDA_1D_KERNEL_LOOP(out_index, out_size) { + int nc = out_index / out_width; + const int out_w = out_index % out_width; + const int out_h = nc % out_height; + nc /= out_height; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + in_h = max(in_h, -in_h); + in_w = max(in_w, -in_w); + in_h = min(in_h, 2 * in_height - in_h - 2); + in_w = min(in_w, 2 * in_width - in_w - 2); + atomicAdd(&d_in_data[(nc * in_height + in_h) * in_width + in_w], + d_out_data[out_index]); + } +} + +template +__global__ void Pad2DGradReflectNHWC(const int out_size, T* d_in_data, + const int num, const int channels, + const int in_height, const int in_width, + const int out_height, const int out_width, + const int pad_top, const int pad_left, + const T* d_out_data) { + CUDA_1D_KERNEL_LOOP(out_index, out_size) { + const int c = out_index % channels; + int n = out_index / channels; + const int out_w = n % out_width; + n /= out_width; + const int out_h = n % out_height; + n /= out_height; + int in_h = out_h - pad_top; + int in_w = out_w - pad_left; + in_h = max(in_h, -in_h); + in_w = max(in_w, -in_w); + in_h = min(in_h, in_height * 2 - in_h - 2); + in_w = min(in_w, in_width * 2 - in_w - 2); + atomicAdd( + &d_in_data[((n * in_height + in_h) * in_width + in_w) * channels + c], + d_out_data[out_index]); + } +} + +template +__global__ void Pad2DGradEdgeNCHW(const int out_size, T* d_in_data, + const int num, const int channels, + const int in_height, const int in_width, + const int out_height, const int out_width, + const int pad_top, const int pad_left, + const T* d_out_data) { + CUDA_1D_KERNEL_LOOP(out_index, out_size) { + int nc = out_index / out_width; + const int out_w = out_index % out_width; + const int out_h = nc % out_height; + nc /= out_height; + const int in_h = min(in_height - 1, max(out_h - pad_top, 0)); + const int in_w = min(in_width - 1, max(out_w - pad_left, 0)); + atomicAdd(&d_in_data[(nc * in_height + in_h) * in_width + in_w], + d_out_data[out_index]); + } +} + +template +__global__ void Pad2DGradEdgeNHWC(const int out_size, T* d_in_data, + const int num, const int channels, + const int in_height, const int in_width, + const int out_height, const int out_width, + const int pad_top, const int pad_left, + const T* d_out_data) { + CUDA_1D_KERNEL_LOOP(out_index, out_size) { + const int c = out_index % channels; + int n = out_index / channels; + const int out_w = n % out_width; + n /= out_width; + const int out_h = n % out_height; + n /= out_height; + const int in_h = min(in_height - 1, max(out_h - pad_top, 0)); + const int in_w = min(in_width - 1, max(out_w - pad_left, 0)); + atomicAdd( + &d_in_data[((n * in_height + in_h) * in_width + in_w) * channels + c], + d_out_data[out_index]); + } +} + +template +class Pad2dCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto pads = context.Attr>("paddings"); + auto mode = context.Attr("mode"); + auto data_format = context.Attr("data_format"); + T value = context.Attr("pad_value"); + auto* x = context.Input("X"); + auto* out = context.Output("Out"); + auto in_dims = x->dims(); + auto out_dims = out->dims(); + const T* in_data = x->data(); + T* out_data = out->mutable_data(context.GetPlace()); + const int pad_top = pads[0]; + const int pad_left = pads[2]; + const int num = in_dims[0]; + + auto stream = context.cuda_device_context().stream(); + int block = PADDLE_CUDA_NUM_THREADS; + const int out_size = out->numel(); + int grid = (out_size + block - 1) / block; + + if (data_format == "NCHW") { + const int channels = in_dims[1]; + const int in_height = in_dims[2]; + const int in_width = in_dims[3]; + const int out_height = out_dims[2]; + const int out_width = out_dims[3]; + if (mode == "reflect") { + Pad2DReflectNCHW<<>>( + out_size, in_data, num, channels, in_height, in_width, out_height, + out_width, pad_top, pad_left, out_data); + } else if (mode == "edge") { + Pad2DEdgeNCHW<<>>( + out_size, in_data, num, channels, in_height, in_width, out_height, + out_width, pad_top, pad_left, out_data); + } else { + Pad2DConstNCHW<<>>( + out_size, in_data, num, channels, in_height, in_width, out_height, + out_width, pad_top, pad_left, value, out_data); + } + } else { + const int channels = in_dims[3]; + const int in_height = in_dims[1]; + const int in_width = in_dims[2]; + const int out_height = out_dims[1]; + const int out_width = out_dims[2]; + if (mode == "reflect") { + Pad2DReflectNHWC<<>>( + out_size, in_data, num, channels, in_height, in_width, out_height, + out_width, pad_top, pad_left, out_data); + } else if (mode == "edge") { + Pad2DEdgeNHWC<<>>( + out_size, in_data, num, channels, in_height, in_width, out_height, + out_width, pad_top, pad_left, out_data); + } else { + Pad2DConstNHWC<<>>( + out_size, in_data, num, channels, in_height, in_width, out_height, + out_width, pad_top, pad_left, value, out_data); + } + } + } +}; + +template +class Pad2dGradCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto pads = context.Attr>("paddings"); + auto mode = context.Attr("mode"); + auto data_format = context.Attr("data_format"); + auto* d_out = context.Input(framework::GradVarName("Out")); + auto* d_in = context.Output(framework::GradVarName("X")); + auto d_in_dims = d_in->dims(); + auto d_out_dims = d_out->dims(); + const T* d_out_data = d_out->data(); + T* d_in_data = d_in->mutable_data(context.GetPlace()); + + math::SetConstant set_zero; + set_zero(context.template device_context(), + d_in, static_cast(0)); + + const int pad_top = pads[0]; + const int pad_left = pads[2]; + const int num = d_in_dims[0]; + + auto stream = context.cuda_device_context().stream(); + int block = PADDLE_CUDA_NUM_THREADS; + const int out_size = d_out->numel(); + const int in_size = d_in->numel(); + int grid = (out_size + block - 1) / block; + + if (data_format == "NCHW") { + const int channels = d_in_dims[1]; + const int in_height = d_in_dims[2]; + const int in_width = d_in_dims[3]; + const int out_height = d_out_dims[2]; + const int out_width = d_out_dims[3]; + if (mode == "reflect") { + Pad2DGradReflectNCHW<<>>( + out_size, d_in_data, num, channels, in_height, in_width, out_height, + out_width, pad_top, pad_left, d_out_data); + } else if (mode == "edge") { + Pad2DGradEdgeNCHW<<>>( + out_size, d_in_data, num, channels, in_height, in_width, out_height, + out_width, pad_top, pad_left, d_out_data); + } else { + grid = (in_size + block - 1) / block; + Pad2DGradConstNCHW<<>>( + in_size, d_in_data, num, channels, in_height, in_width, out_height, + out_width, pad_top, pad_left, d_out_data); + } + } else { + const int channels = d_in_dims[3]; + const int in_height = d_in_dims[1]; + const int in_width = d_in_dims[2]; + const int out_height = d_out_dims[1]; + const int out_width = d_out_dims[2]; + if (mode == "reflect") { + Pad2DGradReflectNHWC<<>>( + out_size, d_in_data, num, channels, in_height, in_width, out_height, + out_width, pad_top, pad_left, d_out_data); + } else if (mode == "edge") { + Pad2DGradEdgeNHWC<<>>( + out_size, d_in_data, num, channels, in_height, in_width, out_height, + out_width, pad_top, pad_left, d_out_data); + } else { + grid = (in_size + block - 1) / block; + Pad2DGradConstNHWC<<>>( + in_size, d_in_data, num, channels, in_height, in_width, out_height, + out_width, pad_top, pad_left, d_out_data); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(pad2d, ops::Pad2dCUDAKernel); +REGISTER_OP_CUDA_KERNEL(pad2d_grad, ops::Pad2dGradCUDAKernel); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 8ef7444a1a3..0ecfc958a3b 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -109,6 +109,7 @@ __all__ = [ 'flatten', 'sequence_mask', 'stack', + 'pad2d', 'unstack', ] @@ -5614,6 +5615,94 @@ def rank_loss(label, left, right, name=None): return out +def pad2d(input, + paddings=[0, 0, 0, 0], + mode='constant', + pad_value=0.0, + data_format="NCHW", + name=None): + """ + Pad 2-d images accordding to 'paddings' and 'mode'. + If mode is 'reflect', paddings[0] and paddings[1] must be no greater + than height-1. And the width dimension has the same condition. + + Example: + + Given that X is a channel of image from input: + + X = [[1, 2, 3], + [4, 5, 6]] + + Case 0: + + paddings = [0, 1, 2, 3], + mode = 'constant' + pad_value = 0 + + Out = [[0, 0, 1, 2, 3, 0, 0, 0] + [0, 0, 4, 5, 6, 0, 0, 0] + [0, 0, 0, 0, 0, 0, 0, 0]] + + Case 1: + + paddings = [0, 1, 2, 1], + mode = 'reflect' + + Out = [[3, 2, 1, 2, 3, 2] + [6, 5, 4, 5, 6, 5] + [3, 2, 1, 2, 3, 2]] + + Case 2: + + paddings = [0, 1, 2, 1], + mode = 'edge' + + Out = [[1, 1, 1, 2, 3, 3] + [4, 4, 4, 5, 6, 6] + [4, 4, 4, 5, 6, 6]] + + + Args: + input (Variable): The input image with [N, C, H, W] format or [N, H, W, C] format. + paddings (tuple|list): The padding size. If padding is a tuple, it must + contain four integers, (padding_top, padding_bottom, padding_left, padding_right). + Default: padding = [0, 0, 0, 0]. + mode (str): Three modes: constant(default), reflect, edge. Default: constant + pad_value (float32): The value to fill the padded areas in constant mode. Default: 0 + data_format (str): An optional string from: "NHWC", "NCHW". Specify the data format of + the input data. + Default: "NCHW" + name (str|None): A name for this layer(optional). If set None, the layer + will be named automatically. + + Returns: + Variable: The tensor variable padded accordding to paddings and mode. + + + Examples: + .. code-block:: python + + data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32') + result = fluid.layers.pad2d(input=data, padding=[1,2,3,4], mode='reflect') + """ + + helper = LayerHelper('pad2d', **locals()) + dtype = helper.input_dtype(input_param_name='input') + out = helper.create_tmp_variable(dtype) + helper.append_op( + type='pad2d', + inputs={'X': input}, + outputs={"Out": out}, + attrs={ + 'paddings': paddings, + 'mode': mode, + 'pad_value': pad_value, + 'data_frmat': data_format + }) + + return out + + def prelu(x, mode, param_attr=None, name=None): """ Equation: @@ -5628,8 +5717,8 @@ def prelu(x, mode, param_attr=None, name=None): all: all elements share same weight channel:elements in a channel share same weight element:each element has a weight - name(str|None): A name for this layer(optional). If set None, the layer - will be named automatically. + name(str|None): A name for this layer(optional). If set None, the layer + will be named automatically. Returns: Variable: The output tensor with the same shape as input. diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index f2fccd5d767..ecdf32524af 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -521,6 +521,20 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(out) print(str(program)) + def test_pad2d(self): + program = Program() + with program_guard(program): + input = layers.data( + name="input", shape=[3, 100, 100], dtype="float32") + out = layers.pad2d( + input, + paddings=[1, 2, 3, 4], + mode='reflect', + data_format='NCHW', + name="shape") + self.assertIsNotNone(out) + print(str(program)) + def test_prelu(self): program = Program() with program_guard(program): diff --git a/python/paddle/fluid/tests/unittests/test_pad2d_op.py b/python/paddle/fluid/tests/unittests/test_pad2d_op.py new file mode 100644 index 00000000000..728b8c181a4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_pad2d_op.py @@ -0,0 +1,102 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest + + +class TestPad2dOp(OpTest): + def setUp(self): + self.pad_value = 0.0 + self.initTestCase() + self.op_type = "pad2d" + self.inputs = {'X': np.random.random(self.shape).astype("float32"), } + self.attrs = {} + self.attrs['paddings'] = np.array(self.paddings).flatten() + self.attrs['pad_value'] = self.pad_value + self.attrs['mode'] = self.mode + self.attrs['data_format'] = self.data_format + if self.data_format == "NCHW": + paddings = [(0, 0), (0, 0), (self.paddings[0], self.paddings[1]), + (self.paddings[2], self.paddings[3])] + else: + paddings = [(0, 0), (self.paddings[0], self.paddings[1]), + (self.paddings[2], self.paddings[3]), (0, 0)] + if self.mode == "constant": + out = np.pad(self.inputs['X'], + paddings, + mode=self.mode, + constant_values=self.pad_value) + else: + out = np.pad(self.inputs['X'], paddings, mode=self.mode) + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X'], 'Out', max_relative_error=0.006) + + def initTestCase(self): + self.shape = (2, 3, 4, 4) + self.paddings = [0, 1, 2, 3] + self.mode = "constant" + self.data_format = "NCHW" + self.pad_value = 0.0 + + +class TestCase1(TestPad2dOp): + def initTestCase(self): + self.shape = (2, 3, 4, 4) + self.paddings = [0, 1, 2, 3] + self.mode = "reflect" + self.data_format = "NCHW" + + +class TestCase2(TestPad2dOp): + def initTestCase(self): + self.shape = (2, 3, 4, 4) + self.paddings = [0, 1, 2, 3] + self.mode = "edge" + self.data_format = "NCHW" + + +class TestCase3(TestPad2dOp): + def initTestCase(self): + self.shape = (2, 4, 4, 2) + self.paddings = [0, 1, 2, 3] + self.mode = "reflect" + self.data_format = "NHWC" + + +class TestCase4(TestPad2dOp): + def initTestCase(self): + self.shape = (2, 4, 4, 2) + self.paddings = [0, 1, 2, 3] + self.mode = "edge" + self.data_format = "NHWC" + + +class TestCase5(TestPad2dOp): + def initTestCase(self): + self.shape = (2, 4, 4, 2) + self.paddings = [0, 1, 2, 3] + self.mode = "constant" + self.pad_value = 1.2 + self.data_format = "NHWC" + + +if __name__ == '__main__': + unittest.main() -- GitLab