diff --git a/cmake/operators.cmake b/cmake/operators.cmake index c17e718f4279f24c85db8be1177e5b5e82b13e08..134c894392a604875780fcfc8ea93e06c9d48bdd 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -110,7 +110,7 @@ function(op_library TARGET) # Define operators that don't need pybind here. foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op" "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op" -"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op") +"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "deformable_conv_op" "dgc_op") if ("${TARGET}" STREQUAL "${manual_pybind_op}") set(pybind_flag 1) endif() diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index eb51a366f16971b03122d50a4dd2863bde215efe..8be10721598ebbbc03f94308317386c9c077cfa6 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -237,6 +237,7 @@ paddle.fluid.layers.fsp_matrix (ArgSpec(args=['x', 'y'], varargs=None, keywords= paddle.fluid.layers.continuous_value_model (ArgSpec(args=['input', 'cvm', 'use_cvm'], varargs=None, keywords=None, defaults=(True,)), ('document', '94e2819b7c9715ea71b62e9c78f36b29')) paddle.fluid.layers.where (ArgSpec(args=['condition'], varargs=None, keywords=None, defaults=None), ('document', '3126e3039e752ce26077f1efaca355c6')) paddle.fluid.layers.sign (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', 'ccf6bb7912afd2818d24bc45461e807a')) +paddle.fluid.layers.deformable_conv (ArgSpec(args=['input', 'offset', 'mask', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'deformable_groups', 'im2col_step', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, None, None, None)), ('document', 'c896b66265a60bd3c5510f66e6e02919')) paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', 'adf285346e23316097f7789b572491e9')) paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'dce69a78638da8f7ad80b1fc00ed2029')) paddle.fluid.layers.read_file (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', '32181f6037e387fb6e68a5beaafe33b6')) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 7a9c03ab5334a2c8c140743eafd23f3ac12fb7f9..b7abc68949c6e514626a7969086094b3f08830f4 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -47,7 +47,8 @@ if (WITH_DISTRIBUTE) SET(OP_PREFETCH_DEPS ${OP_PREFETCH_DEPS} parameter_prefetch) endif() -register_operators(EXCLUDES py_func_op warpctc_op dgc_op conv_fusion_op sync_batch_norm_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS}) +register_operators(EXCLUDES py_func_op warpctc_op dgc_op conv_fusion_op + sync_batch_norm_op deformable_conv_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS}) if (WITH_GPU) # warpctc_op needs cudnn 7 above @@ -65,6 +66,8 @@ if (WITH_GPU) op_library(sync_batch_norm_op) file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sync_batch_norm);\n") endif() + op_library(deformable_conv_op) + file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(deformable_conv);\n") else() op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) endif() diff --git a/paddle/fluid/operators/deformable_conv_op.cc b/paddle/fluid/operators/deformable_conv_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..92a93dc74758a10bc7eca46c9e2e4e7a8fc52fe2 --- /dev/null +++ b/paddle/fluid/operators/deformable_conv_op.cc @@ -0,0 +1,278 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/conv_op.h" + +namespace paddle { +namespace operators { +class DeformableConvOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Input", + "(Tensor) The input of deformable conv op. " + "The shape of input is " + "[N, channel_in, H, W]"); + AddInput("Offset", + "(Tensor) The input offset. " + "The shape of the offset is " + "[N, deformable_groups * kernel_w * kernel_h * 2, H, W"); + AddInput("Mask", + "(Tensor) The input mask. " + "The shape of the mask is " + "[N, deformable_groups * kernel_w * kernel_h, H, W]."); + AddInput("Filter", + "(Tensor) The Input Filter " + "The shape of the wight is " + "[num_filters, channel_in, kernel_h, kernel_w."); + AddOutput("Output", + "(Tensor) The output. " + "The shape of the output tensor is " + "[N, num_filters, out_height, out_width]]."); + AddAttr>("strides", + "(vector default:{1, 1}), the " + "strides(h_stride, w_stride) of " + "convolution operator.") + .SetDefault({1, 1}); + AddAttr>("paddings", + "(vector default:{0,0}), the " + "paddings(h_pad, w_pad) of " + "convolution operator. ") + .SetDefault({0, 0}); + AddAttr>("dilations", + "(vector default:{1, 1}), the " + "dilations(h_dilation, w_dilation) of " + "convolution operator.") + .SetDefault({1, 1}); + AddAttr( + "groups", + "(int default:1), the groups number of the convolution operator. " + "According to grouped convolution in Alex Krizhevsky's Deep CNN paper: " + "when group=2, the first half of the filters is only connected to the " + "first half of the input channels, while the second half of the " + "filters " + "is only connected to the second half of the input channels.") + .SetDefault(1); + AddAttr("deformable_groups", + "(int default:1), the number of the deformable groups.") + .SetDefault(1); + AddAttr("im2col_step", + "im2col maximum number of image per computation") + .SetDefault(64); + AddComment(R"DOC( +**Deformable Convolution Operator** + +Compute 2-D deformable convolution on 4-D input. + +Given input image x, output feature map y, the deformable convolution operation can be expressed as follow: + +$$ +y(p) = \\sum_{k=1}^{K}{w_k * x(p + p_k + \\Delta p_k) * \\Delta m_k} +$$ + +Where $$\\Delta p_k$$ and $$\Delta m_k$$ are the learnable offset and modulation scalar for the k-th location, respectively. + +Refer to 'Deformable ConvNets v2: More Deformable, Better Results +' + +Example: + Input: + Input shape: $(N, C_{in}, H_{in}, W_{in})$ + Filter shape: $(C_{out}, C_{in}, H_f, W_f)$ + Offset shape: $(N, 2 * deformable_groups, * H_f * W_f, H_{out}, W_{out})$ + Mask shape: $(N, deformable_groups * H_f * W_f, H_{out}, W_{out})$ + Output: + Output shape: $(N, C_{out}, H_{out}, W_{out})$ + where $H_{out}, W_{out}$ must be equal to $H_{in}, W_{in}$ respectively. + Where +$$ + H_{out}= \frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]}+ 1 \\ + W_{out}= \frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]}+ 1 +$$ +)DOC"); + } +}; + +class DeformableConvOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(Input) of DeformableConvOp " + "should not be null"); + PADDLE_ENFORCE(ctx->HasInput("Offset"), + "Input(Offset) of DeformableConvOp " + "should not be null"); + PADDLE_ENFORCE(ctx->HasInput("Mask"), + "Input(Mask) of DeformableConvOp " + "should not be null"); + PADDLE_ENFORCE(ctx->HasInput("Filter"), + "Input(Filter) of DeformableConvOp " + "should not be null"); + PADDLE_ENFORCE(ctx->HasOutput("Output"), + "Output(Output) of DeformableConvOp " + "should not be null."); + + auto in_dims = ctx->GetInputDim("Input"); + auto filter_dims = ctx->GetInputDim("Filter"); + auto offset_dims = ctx->GetInputDim("Offset"); + auto mask_dims = ctx->GetInputDim("Mask"); + + std::vector strides = ctx->Attrs().Get>("strides"); + std::vector paddings = ctx->Attrs().Get>("paddings"); + std::vector dilations = + ctx->Attrs().Get>("dilations"); + int groups = ctx->Attrs().Get("groups"); + int deformable_groups = ctx->Attrs().Get("deformable_groups"); + int im2col_step = ctx->Attrs().Get("im2col_step"); + + PADDLE_ENFORCE(in_dims.size() == 4, + "Conv input should be 4-D tensor, get %u", in_dims.size()); + PADDLE_ENFORCE_EQ( + in_dims.size(), filter_dims.size(), + "Conv input dimension and filter dimension should be the same."); + PADDLE_ENFORCE_EQ( + in_dims.size() - strides.size(), 2U, + "Conv input dimension and strides dimension should be consistent."); + PADDLE_ENFORCE_EQ(paddings.size(), strides.size(), + "Conv paddings dimension and Conv strides dimension " + "should be the same."); + + PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[1] * groups, + "The number of input channels should be equal to filter " + "channels * groups."); + PADDLE_ENFORCE_EQ( + filter_dims[0] % groups, 0, + "The number of output channels should be divided by groups."); + PADDLE_ENFORCE_EQ(filter_dims[0] % deformable_groups, 0, + "The number of output channels should be " + "divided by deformable groups."); + + if (in_dims[0] > im2col_step) { + PADDLE_ENFORCE_EQ( + in_dims[0] % im2col_step, 0U, + "Input batchsize must be smaller than or divide im2col_step"); + } + + for (size_t i = 0; i < strides.size(); ++i) { + PADDLE_ENFORCE_GT(strides[i], 0U, "stride %d size incorrect", i); + } + for (size_t i = 0; i < dilations.size(); ++i) { + PADDLE_ENFORCE_GT(dilations[i], 0U, "dilation %d size incorrect", i); + } + + std::vector output_shape({in_dims[0], filter_dims[0]}); + for (size_t i = 0; i < strides.size(); ++i) { + output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], + dilations[i], paddings[i], + strides[i])); + } + PADDLE_ENFORCE_EQ(output_shape[1] % deformable_groups, 0U, + "output num_filter must divide deformable group size."); + PADDLE_ENFORCE_EQ(output_shape[2], offset_dims[2], + "output height must equal to offset map height."); + PADDLE_ENFORCE_EQ(output_shape[3], offset_dims[3], + "output width must equal to offset map width."); + PADDLE_ENFORCE_EQ(offset_dims[1] % (filter_dims[2] * filter_dims[3]), 0U, + "offset filter must divide deformable group size."); + PADDLE_ENFORCE_EQ(offset_dims[1] / (2 * filter_dims[2] * filter_dims[3]), + deformable_groups, + "offset filter must divide deformable group size."); + PADDLE_ENFORCE_EQ(output_shape[2], mask_dims[2], + "output height must equal to mask map height."); + PADDLE_ENFORCE_EQ(output_shape[3], mask_dims[3], + "output width must equal to mask map width."); + PADDLE_ENFORCE_EQ(mask_dims[1] % (filter_dims[2] * filter_dims[3]), 0U, + "mask filter must divide deformable group size."); + PADDLE_ENFORCE_EQ(mask_dims[1] / (filter_dims[2] * filter_dims[3]), + deformable_groups, + "mask filter must divide deformable group size."); + + ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType(ctx.Input("Input")->type(), + ctx.device_context()); + } +}; + +class DeformableConvGradOpDescMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new framework::OpDesc()); + + op->SetType("deformable_conv_grad"); + op->SetInput("Input", Input("Input")); + op->SetInput("Filter", Input("Filter")); + op->SetInput("Offset", Input("Offset")); + op->SetInput("Mask", Input("Mask")); + op->SetInput(framework::GradVarName("Output"), OutputGrad("Output")); + + op->SetOutput(framework::GradVarName("Input"), InputGrad("Input")); + op->SetOutput(framework::GradVarName("Filter"), InputGrad("Filter")); + op->SetOutput(framework::GradVarName("Offset"), InputGrad("Offset")); + op->SetOutput(framework::GradVarName("Mask"), InputGrad("Mask")); + + op->SetAttrMap(Attrs()); + return op; + } +}; + +class DeformableConvGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + auto in_dims = ctx->GetInputDim("Input"); + auto filter_dims = ctx->GetInputDim("Filter"); + auto offset_dims = ctx->GetInputDim("Offset"); + auto mask_dims = ctx->GetInputDim("Mask"); + + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Output")), + "the gradient of output(Out) must not be null"); + if (ctx->HasOutput(framework::GradVarName("Input"))) { + ctx->SetOutputDim(framework::GradVarName("Input"), in_dims); + } + if (ctx->HasOutput(framework::GradVarName("Filter"))) { + ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims); + } + if (ctx->HasOutput(framework::GradVarName("Offset"))) { + ctx->SetOutputDim(framework::GradVarName("Offset"), offset_dims); + } + if (ctx->HasOutput(framework::GradVarName("Mask"))) { + ctx->SetOutputDim(framework::GradVarName("Mask"), mask_dims); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType(ctx.Input("Input")->type(), + ctx.device_context()); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(deformable_conv, ops::DeformableConvOp, + ops::DeformableConvOpMaker, + ops::DeformableConvGradOpDescMaker); + +REGISTER_OPERATOR(deformable_conv_grad, ops::DeformableConvGradOp); diff --git a/paddle/fluid/operators/deformable_conv_op.cu b/paddle/fluid/operators/deformable_conv_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..7d18d503d496eaedccfbf294f515ae7fbf7051ec --- /dev/null +++ b/paddle/fluid/operators/deformable_conv_op.cu @@ -0,0 +1,743 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/cuda_primitives.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaximumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaximumNumBlocks); +} + +template +__device__ T DmcnGetGradientWeight(T argmax_h, T argmax_w, const int h, + const int w, const int height, + const int width) { + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || + argmax_w >= width) { + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + T weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__global__ void ModulatedDeformableCol2imGpuKernel( + const int nthreads, const T* data_col, const T* data_offset, + const T* data_mask, const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int channel_per_deformable_group, + const int batch_size, const int deformable_group, const int height_col, + const int width_col, T* grad_im) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (size_t thread = index; thread < nthreads; thread += offset) { + const int j = (thread / width_col / height_col / batch_size) % kernel_w; + const int i = + (thread / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = + thread / width_col / height_col / batch_size / kernel_w / kernel_h; + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = thread % width_col; + int h_out = (thread / width_col) % height_col; + int b = (thread / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const T* data_offset_ptr = data_offset + + (b * deformable_group + deformable_group_index) * + 2 * kernel_h * kernel_w * height_col * + width_col; + const T* data_mask_ptr = data_mask + + (b * deformable_group + deformable_group_index) * + kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = + ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + const T mask = data_mask_ptr[data_mask_hw_ptr]; + const T cur_inv_h_data = h_in + i * dilation_h + offset_h; + const T cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const T cur_top_grad = data_col[thread] * mask; + const int cur_h = static_cast(cur_inv_h_data); + const int cur_w = static_cast(cur_inv_w_data); + for (int dy = -2; dy <= 2; dy++) { + for (int dx = -2; dx <= 2; dx++) { + if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && + cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) { + int cur_bottom_grad_pos = + ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + T weight = + DmcnGetGradientWeight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, + cur_w + dx, height, width); + + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +template +inline void ModulatedDeformableCol2im( + const platform::DeviceContext& ctx, const T* data_col, const T* data_offset, + const T* data_mask, const std::vector im_shape, + const std::vector col_shape, + const std::vector kernel_shape, const std::vector pad, + const std::vector stride, const std::vector dilation, + const int deformable_group, T* grad_im) { + int channel_per_deformable_group = im_shape[0] / deformable_group; + int num_kernels = col_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; + int blocks = NumBlocks(num_kernels); + int threads = kNumCUDAThreads; + + ModulatedDeformableCol2imGpuKernel<<< + blocks, threads, 0, + reinterpret_cast(ctx).stream()>>>( + num_kernels, data_col, data_offset, data_mask, im_shape[0], im_shape[1], + im_shape[2], kernel_shape[2], kernel_shape[3], pad[0], pad[1], stride[0], + stride[1], dilation[0], dilation[1], channel_per_deformable_group, + col_shape[1], deformable_group, col_shape[2], col_shape[3], grad_im); +} + +template +__device__ T DmcnGetCoordinateWeight(T argmax_h, T argmax_w, const int height, + const int width, const T* im_data, + const int data_width, const int bp_dir) { + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || + argmax_w >= width) { + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + T weight = 0; + + if (bp_dir == 0) { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * + im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * + im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * + im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * + im_data[argmax_h_high * data_width + argmax_w_high]; + } else if (bp_dir == 1) { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * + im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * + im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * + im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * + im_data[argmax_h_high * data_width + argmax_w_high]; + } + return weight; +} + +template +__global__ void ModulatedDeformableCol2imCoordGpuKernel( + const int nthreads, const T* data_col, const T* data_im, + const T* data_offset, const T* data_mask, const int channels, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, const int batch_size, + const int offset_channels, const int deformable_group, const int height_col, + const int width_col, T* grad_offset, T* grad_mask) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (size_t i = index; i < nthreads; i += offset) { + T val = 0, mval = 0; + const int w = i % width_col; + const int h = (i / width_col) % height_col; + const int c = (i / width_col / height_col) % offset_channels; + const int b = (i / width_col / height_col) / offset_channels; + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const T* data_col_ptr = data_col + + deformable_group_index * + channel_per_deformable_group * batch_size * + width_col * height_col; + const T* data_im_ptr = data_im + + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / + kernel_w * height * width; + const T* data_offset_ptr = data_offset + + (b * deformable_group + deformable_group_index) * + 2 * kernel_h * kernel_w * height_col * + width_col; + const T* data_mask_ptr = data_mask + + (b * deformable_group + deformable_group_index) * + kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = offset_c / 2; col_c < channel_per_deformable_group; + col_c += col_step) { + const int col_pos = + (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = + (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = + (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = + (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + + w_out); + const int data_mask_hw_ptr = + (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + const T mask = data_mask_ptr[data_mask_hw_ptr]; + T inv_h = h_in + i * dilation_h + offset_h; + T inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) { + inv_h = inv_w = -2; + } else { + mval += data_col_ptr[col_pos] * + DmcnIm2colBilinear(data_im_ptr + cnt * height * width, width, + height, width, inv_h, inv_w); + } + const T weight = DmcnGetCoordinateWeight( + inv_h, inv_w, height, width, data_im_ptr + cnt * height * width, + width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + grad_offset[i] = val; + if (offset_c % 2 == 0) + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * + kernel_w + + offset_c / 2) * + height_col + + h) * + width_col + + w] = mval; + } +} + +template +inline void ModulatedDeformableCol2imCoord( + const platform::DeviceContext& ctx, const T* data_col, const T* data_im, + const T* data_offset, const T* data_mask, + const std::vector im_shape, const std::vector col_shape, + const std::vector kernel_shape, const std::vector paddings, + const std::vector strides, const std::vector dilations, + const int deformable_groups, T* grad_offset, T* grad_mask) { + int num_kernels = 2 * kernel_shape[2] * kernel_shape[3] * col_shape[1] * + col_shape[2] * col_shape[3] * deformable_groups; + int channel_per_deformable_group = col_shape[0] / deformable_groups; + int blocks = NumBlocks(num_kernels); + int threads = kNumCUDAThreads; + + ModulatedDeformableCol2imCoordGpuKernel<<< + blocks, threads, 0, + reinterpret_cast(ctx).stream()>>>( + num_kernels, data_col, data_im, data_offset, data_mask, im_shape[0], + im_shape[1], im_shape[2], kernel_shape[2], kernel_shape[3], paddings[0], + paddings[1], strides[0], strides[1], dilations[0], dilations[1], + channel_per_deformable_group, col_shape[1], + 2 * kernel_shape[2] * kernel_shape[3] * deformable_groups, + deformable_groups, col_shape[2], col_shape[3], grad_offset, grad_mask); +} + +template +__device__ T DmcnIm2colBilinear(const T* bottom_data, const int data_width, + const int height, const int width, T h, T w) { + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + T lh = h - h_low; + T lw = w - w_low; + T hh = 1 - lh, hw = 1 - lw; + + T v1 = 0; + if (h_low >= 0 && w_low >= 0) v1 = bottom_data[h_low * data_width + w_low]; + T v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + T v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + T v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__global__ void ModulatedDeformableIm2colGpuKernel( + const int nthreads, const T* data_im, const T* data_offset, + const T* data_mask, const int height, const int width, const int kernel_h, + const int kernel_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, const int batch_size, + const int num_channels, const int deformable_group, const int height_col, + const int width_col, T* data_col) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (size_t i = index; i < nthreads; i += offset) { + const int w_col = i % width_col; + const int h_col = (i / width_col) % height_col; + const int b_col = (i / width_col) / height_col % batch_size; + const int c_im = (i / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + T* data_col_ptr = + data_col + + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + const T* data_im_ptr = + data_im + (b_col * num_channels + c_im) * height * width; + const T* data_offset_ptr = + data_offset + + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * + kernel_w * height_col * width_col; + const T* data_mask_ptr = + data_mask + + (b_col * deformable_group + deformable_group_index) * kernel_h * + kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) { + for (int j = 0; j < kernel_w; ++j) { + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + + w_col; + const int data_mask_hw_ptr = + ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + const T mask = data_mask_ptr[data_mask_hw_ptr]; + T val = static_cast(0); + const T h_im = h_in + i * dilation_h + offset_h; + const T w_im = w_in + j * dilation_w + offset_w; + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) { + val = + DmcnIm2colBilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + } + } + } +} + +template +inline void ModulatedDeformableIm2col( + const platform::DeviceContext& ctx, const T* data_im, const T* data_offset, + const T* data_mask, const std::vector im_shape, + const std::vector col_shape, + const std::vector filter_shape, const std::vector paddings, + const std::vector strides, const std::vector dilations, + const int deformable_groups, T* data_col) { + int channel_per_deformable_group = im_shape[0] / deformable_groups; + int num_kernels = im_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; + + int blocks = NumBlocks(num_kernels); + int threads = kNumCUDAThreads; + + ModulatedDeformableIm2colGpuKernel<<< + blocks, threads, 0, + reinterpret_cast(ctx).stream()>>>( + num_kernels, data_im, data_offset, data_mask, im_shape[1], im_shape[2], + filter_shape[2], filter_shape[3], paddings[0], paddings[1], strides[0], + strides[1], dilations[0], dilations[1], channel_per_deformable_group, + col_shape[1], im_shape[0], deformable_groups, col_shape[2], col_shape[3], + data_col); +} + +template +__global__ void FilterGradAddupGpuKernel(const int nthreads, const int n, + const int height, const int width, + const T* dweight_3d, T* filter_grad) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (size_t i = index; i < nthreads; i += offset) { + filter_grad[i] = filter_grad[i] + dweight_3d[i]; + } +} + +template +class DeformableConvCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const Tensor* input = ctx.Input("Input"); + const Tensor offset = *ctx.Input("Offset"); + const Tensor mask = *ctx.Input("Mask"); + Tensor filter = *ctx.Input("Filter"); + Tensor* output = ctx.Output("Output"); + output->mutable_data(ctx.GetPlace()); + + auto& dev_ctx = ctx.cuda_device_context(); + + const int groups = ctx.Attr("groups"); + const int deformable_groups = ctx.Attr("deformable_groups"); + const int im2col_step = ctx.Attr("im2col_step"); + const std::vector strides = ctx.Attr>("strides"); + const std::vector paddings = ctx.Attr>("paddings"); + const std::vector dilations = ctx.Attr>("dilations"); + + const int batch_size = static_cast(input->dims()[0]); + + std::vector filter_shape_vec(framework::vectorize(filter.dims())); + std::vector output_shape_vec(framework::vectorize(output->dims())); + + // col_shape_vec: {c_i * k_h * k_w, im2col_step, o_h, o_w} + std::vector col_buffer_shape_vec(filter_shape_vec.size()); + col_buffer_shape_vec[0] = + input->dims()[1] * filter.dims()[2] * filter.dims()[3]; + col_buffer_shape_vec[1] = im2col_step; + for (size_t j = 0; j < filter_shape_vec.size() - 2; ++j) { + col_buffer_shape_vec[j + 2] = output_shape_vec[j + 2]; + } + framework::DDim col_shape(framework::make_ddim(col_buffer_shape_vec)); + std::vector output_buffer_shape_vec(1); + output_buffer_shape_vec[0] = batch_size * output_shape_vec[1] * + output_shape_vec[2] * output_shape_vec[3]; + framework::DDim output_shape(framework::make_ddim(output_buffer_shape_vec)); + Tensor col_buffer; + Tensor output_buffer; + col_buffer = ctx.AllocateTmpTensor(col_shape, dev_ctx); + output_buffer = + ctx.AllocateTmpTensor(output_shape, dev_ctx); + + int64_t M = output_shape_vec[1] / groups; + int64_t N = im2col_step * output_shape_vec[2] * output_shape_vec[3]; + int64_t K = + input->dims()[1] * filter_shape_vec[2] * filter_shape_vec[3] / groups; + + Tensor weight_3d; + weight_3d.ShareDataWith(filter).Resize( + framework::make_ddim({groups, M, K})); + Tensor col_buffer_3d; + col_buffer_3d.ShareDataWith(col_buffer) + .Resize(framework::make_ddim({groups, K, N})); + Tensor output_4d; + output_4d.ShareDataWith(output_buffer) + .Resize(framework::make_ddim({batch_size / im2col_step, groups, M, N})); + output_4d.mutable_data(ctx.GetPlace()); + framework::DDim input_shape = + framework::slice_ddim(input->dims(), 1, input->dims().size()); + std::vector input_shape_vec = framework::vectorize(input_shape); + + int input_dim = input->numel() / input->dims()[0]; + int input_offset_dim = offset.numel() / offset.dims()[0]; + int input_mask_dim = mask.numel() / mask.dims()[0]; + + auto blas = math::GetBlas(dev_ctx); + + const T* input_ptr = input->data(); + const T* offset_ptr = offset.data(); + const T* mask_ptr = mask.data(); + col_buffer.mutable_data(ctx.GetPlace()); + T* col_buffer_ptr = col_buffer.data(); + + for (int i = 0; i < batch_size / im2col_step; ++i) { + ModulatedDeformableIm2col( + ctx.device_context(), input_ptr + i * im2col_step * input_dim, + offset_ptr + i * im2col_step * input_offset_dim, + mask_ptr + i * im2col_step * input_mask_dim, input_shape_vec, + col_buffer_shape_vec, filter_shape_vec, paddings, strides, dilations, + deformable_groups, col_buffer_ptr); + + Tensor output_3d = output_4d.Slice(i, i + 1).Resize( + framework::slice_ddim(output_4d.dims(), 1, output_4d.dims().size())); + for (int g = 0; g < groups; ++g) { + Tensor weight_3d_slice = + weight_3d.Slice(g, g + 1).Resize(framework::slice_ddim( + weight_3d.dims(), 1, weight_3d.dims().size())); + Tensor col_buffer_3d_slice = + col_buffer_3d.Slice(g, g + 1).Resize(framework::slice_ddim( + col_buffer_3d.dims(), 1, col_buffer_3d.dims().size())); + Tensor output_3d_slice = + output_3d.Slice(g, g + 1).Resize(framework::slice_ddim( + output_3d.dims(), 1, output_3d.dims().size())); + + blas.MatMul(weight_3d_slice, false, col_buffer_3d_slice, false, T(1.0), + &output_3d_slice, T(0.0)); + } + } + output->ShareDataWith(output_buffer) + .Resize(framework::make_ddim(output_shape_vec)); + } +}; + +template +class DeformableConvGradCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const Tensor* output_grad = + ctx.Input(framework::GradVarName("Output")); + Tensor* input_grad = ctx.Output(framework::GradVarName("Input")); + Tensor* filter_grad = ctx.Output(framework::GradVarName("Filter")); + Tensor* offset_grad = ctx.Output(framework::GradVarName("Offset")); + Tensor* mask_grad = ctx.Output(framework::GradVarName("Mask")); + + const Tensor* input = ctx.Input("Input"); + Tensor offset = *ctx.Input("Offset"); + Tensor mask = *ctx.Input("Mask"); + Tensor filter = *ctx.Input("Filter"); + if (!input_grad && !filter_grad && !offset_grad && !mask_grad) return; + + int groups = ctx.Attr("groups"); + int deformable_groups = ctx.Attr("deformable_groups"); + int im2col_step = ctx.Attr("im2col_step"); + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + std::vector dilations = ctx.Attr>("dilations"); + + auto& dev_ctx = ctx.cuda_device_context(); + const int batch_size = static_cast(input->dims()[0]); + + framework::DDim input_shape = + framework::slice_ddim(input->dims(), 1, input->dims().size()); + std::vector input_shape_vec = framework::vectorize(input_shape); + std::vector filter_shape_vec(framework::vectorize(filter.dims())); + std::vector output_shape_vec( + framework::vectorize(output_grad->dims())); + + std::vector col_buffer_shape_vec(filter_shape_vec.size()); + col_buffer_shape_vec[0] = + input->dims()[1] * filter.dims()[2] * filter.dims()[3]; + col_buffer_shape_vec[1] = im2col_step; + for (size_t j = 0; j < filter_shape_vec.size() - 2; ++j) { + col_buffer_shape_vec[j + 2] = output_shape_vec[j + 2]; + } + framework::DDim col_shape(framework::make_ddim(col_buffer_shape_vec)); + std::vector output_buffer_shape_vec(1); + output_buffer_shape_vec[0] = batch_size * output_shape_vec[1] * + output_shape_vec[2] * output_shape_vec[3]; + framework::DDim output_shape(framework::make_ddim(output_buffer_shape_vec)); + Tensor col_buffer; + Tensor output_buffer; + col_buffer = ctx.AllocateTmpTensor(col_shape, dev_ctx); + output_buffer = + ctx.AllocateTmpTensor(output_shape, dev_ctx); + + output_buffer.ShareDataWith(*output_grad); + + int64_t M = + input_shape_vec[0] / groups * filter_shape_vec[2] * filter_shape_vec[3]; + int64_t N = im2col_step * output_shape_vec[2] * output_shape_vec[3]; + int64_t K = output_shape_vec[1] / groups; + + framework::DDim weight_3d_shape = {groups, K, M}; + framework::DDim out_grad_4d_shape = {batch_size / im2col_step, groups, K, + N}; + framework::DDim col_buffer_3d_shape = {groups, M, N}; + framework::DDim filter_grad_shape = {groups, K, M}; + + Tensor weight_3d; + weight_3d.ShareDataWith(filter).Resize(weight_3d_shape); + Tensor out_grad_4d; + out_grad_4d.ShareDataWith(output_buffer).Resize(out_grad_4d_shape); + Tensor col_buffer_3d; + col_buffer_3d.ShareDataWith(col_buffer).Resize(col_buffer_3d_shape); + + math::SetConstant set_zero; + auto blas = math::GetBlas(dev_ctx); + + col_buffer.mutable_data(ctx.GetPlace()); + col_buffer_3d.mutable_data(ctx.GetPlace()); + out_grad_4d.mutable_data(ctx.GetPlace()); + + int input_dim = input->numel() / input->dims()[0]; + int input_offset_dim = offset.numel() / offset.dims()[0]; + int input_mask_dim = mask.numel() / mask.dims()[0]; + + if (filter_grad) { + filter_grad->mutable_data(ctx.GetPlace()); + filter_grad->Resize(filter_grad_shape); + set_zero(dev_ctx, filter_grad, static_cast(0)); + } + + if (input_grad) { + input_grad->mutable_data(ctx.GetPlace()); + set_zero(dev_ctx, input_grad, static_cast(0)); + } + + if (offset_grad && mask_grad) { + offset_grad->mutable_data(ctx.GetPlace()); + mask_grad->mutable_data(ctx.GetPlace()); + set_zero(dev_ctx, offset_grad, static_cast(0)); + set_zero(dev_ctx, mask_grad, static_cast(0)); + } + + for (int i = 0; i < batch_size / im2col_step; ++i) { + Tensor out_grad_3d = + out_grad_4d.Slice(i, i + 1).Resize(framework::slice_ddim( + out_grad_4d.dims(), 1, out_grad_4d.dims().size())); + for (int g = 0; g < groups; ++g) { + Tensor weight_3d_slice = + weight_3d.Slice(g, g + 1).Resize(framework::slice_ddim( + weight_3d.dims(), 1, weight_3d.dims().size())); + Tensor out_grad_3d_slice = + out_grad_3d.Slice(g, g + 1).Resize(framework::slice_ddim( + out_grad_3d.dims(), 1, out_grad_3d.dims().size())); + Tensor col_buffer_3d_slice = + col_buffer_3d.Slice(g, g + 1).Resize(framework::slice_ddim( + col_buffer_3d.dims(), 1, col_buffer_3d.dims().size())); + + blas.MatMul(weight_3d_slice, true, out_grad_3d_slice, false, T(1.0), + &col_buffer_3d_slice, T(0.0)); + } + col_buffer.Resize(col_shape); + + T* col_buffer_ptr = col_buffer.data(); + const T* input_ptr = input->data(); + const T* offset_ptr = offset.data(); + const T* mask_ptr = mask.data(); + + if (mask_grad && offset_grad) { + T* offset_grad_ptr = offset_grad->data(); + T* mask_grad_ptr = mask_grad->data(); + ModulatedDeformableCol2imCoord( + ctx.device_context(), col_buffer_ptr, + input_ptr + i * im2col_step * input_dim, + offset_ptr + i * im2col_step * input_offset_dim, + mask_ptr + i * im2col_step * input_mask_dim, input_shape_vec, + col_buffer_shape_vec, filter_shape_vec, paddings, strides, + dilations, deformable_groups, + offset_grad_ptr + i * im2col_step * input_offset_dim, + mask_grad_ptr + i * im2col_step * input_mask_dim); + } + if (input_grad) { + T* input_grad_ptr = input_grad->data(); + ModulatedDeformableCol2im( + ctx.device_context(), col_buffer_ptr, + offset_ptr + i * im2col_step * input_offset_dim, + mask_ptr + i * im2col_step * input_mask_dim, input_shape_vec, + col_buffer_shape_vec, filter_shape_vec, paddings, strides, + dilations, deformable_groups, + input_grad_ptr + i * im2col_step * input_dim); + input_grad->Resize(input->dims()); + } + + ModulatedDeformableIm2col( + ctx.device_context(), input_ptr + i * im2col_step * input_dim, + offset_ptr + i * im2col_step * input_offset_dim, + mask_ptr + i * im2col_step * input_mask_dim, input_shape_vec, + col_buffer_shape_vec, filter_shape_vec, paddings, strides, dilations, + deformable_groups, col_buffer_ptr); + + col_buffer_3d.Resize(col_buffer_3d_shape); + + if (filter_grad) { + Tensor dweight_3d; + dweight_3d = + ctx.AllocateTmpTensor(filter_grad_shape, dev_ctx); + for (int g = 0; g < groups; ++g) { + Tensor out_grad_3d_slice = + out_grad_3d.Slice(g, g + 1).Resize(framework::slice_ddim( + out_grad_3d.dims(), 1, out_grad_3d.dims().size())); + Tensor col_buffer_3d_slice = + col_buffer_3d.Slice(g, g + 1).Resize(framework::slice_ddim( + col_buffer_3d.dims(), 1, col_buffer_3d.dims().size())); + Tensor dweight_3d_slice = + dweight_3d.Slice(g, g + 1).Resize(framework::slice_ddim( + dweight_3d.dims(), 1, dweight_3d.dims().size())); + + blas.MatMul(out_grad_3d_slice, false, col_buffer_3d_slice, true, + T(1.0), &dweight_3d_slice, T(0.0)); + } + FilterGradAddupGpuKernel< + T><<>>( + dweight_3d.numel(), groups, K, M, dweight_3d.data(), + filter_grad->data()); + } + } + if (filter_grad) { + filter_grad->Resize(filter.dims()); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +using CUDA = paddle::platform::CUDADeviceContext; + +REGISTER_OP_CUDA_KERNEL(deformable_conv, + ops::DeformableConvCUDAKernel); +REGISTER_OP_CUDA_KERNEL(deformable_conv_grad, + ops::DeformableConvGradCUDAKernel); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 144713e0396de1492e646cd59c0f67961e513baf..94af5775c83dbcd2dfa37e5569011c3ac9b2f2a7 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -202,6 +202,7 @@ __all__ = [ 'continuous_value_model', 'where', 'sign', + 'deformable_conv', ] kIgnoreIndex = -100 @@ -11745,3 +11746,175 @@ def sign(x): helper.append_op(type='sign', inputs={'X': [x]}, outputs={'Out': [out]}) return out + + +def deformable_conv(input, + offset, + mask, + num_filters, + filter_size, + stride=1, + padding=0, + dilation=1, + groups=None, + deformable_groups=None, + im2col_step=None, + param_attr=None, + bias_attr=None, + name=None): + """ + **Deformable Convolution Layer** + + Compute 2-D deformable convolution on 4-D input. + Given input image x, output feature map y, the deformable convolution operation can be expressed as follow: + + .. math:: + + y(p) = \sum_{k=1}^{K}{w_k * x(p + p_k + \Delta p_k) * \Delta m_k} + + Where :math:`\Delta p_k` and :math:`\Delta m_k` are the learnable offset and modulation scalar for the k-th location, respectively. + Refer to `Deformable ConvNets v2: More Deformable, Better Results + `_ . + + Example: + - Input: + + Input shape: :math:`(N, C_{in}, H_{in}, W_{in})` + + Filter shape: :math:`(C_{out}, C_{in}, H_f, W_f)` + + Offset shape: :math:`(N, 2 * deformable\_groups * H_f * H_w, H_{in}, W_{in})` + + Mask shape: :math:`(N, deformable\_groups * H_f * H_w, H_{in}, W_{in})` + + - Output: + + Output shape: :math:`(N, C_{out}, H_{out}, W_{out})` + + Where + + .. math:: + + H_{out}&= \\frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]} + 1 \\\\ + W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1 + + Args: + input (Variable): The input image with [N, C, H, W] format. + offset (Variable): The input coord offset of deformable convolution layer. + Mask (Variable): The input mask of deformable covolution layer. + num_filters(int): The number of filter. It is as same as the output + image channel. + filter_size (int|tuple|None): The filter size. If filter_size is a tuple, + it must contain two integers, (filter_size_H, filter_size_W). + Otherwise, the filter will be a square. + stride (int|tuple): The stride size. If stride is a tuple, it must + contain two integers, (stride_H, stride_W). Otherwise, the + stride_H = stride_W = stride. Default: stride = 1. + padding (int|tuple): The padding size. If padding is a tuple, it must + contain two integers, (padding_H, padding_W). Otherwise, the + padding_H = padding_W = padding. Default: padding = 0. + dilation (int|tuple): The dilation size. If dilation is a tuple, it must + contain two integers, (dilation_H, dilation_W). Otherwise, the + dilation_H = dilation_W = dilation. Default: dilation = 1. + groups (int): The groups number of the deformable conv layer. According to + grouped convolution in Alex Krizhevsky's Deep CNN paper: when group=2, + the first half of the filters is only connected to the first half + of the input channels, while the second half of the filters is only + connected to the second half of the input channels. Default: groups=1. + deformable_groups (int): The number of deformable group partitions. + Default: deformable_groups = 1. + im2col_step (int): Maximum number of images per im2col computation; + The total batch size should be divisable by this value or smaller + than this value; if you face out of memory problem, you can try + to use a smaller value here. + Default: im2col_step = 64. + param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights + of deformable conv. If it is set to None or one attribute of ParamAttr, + deformable conv will create ParamAttr as param_attr. + If the Initializer of the param_attr is not set, the parameter is + initialized with :math:`Normal(0.0, std)`, and the + :math:`std` is :math:`(\\frac{2.0 }{filter\_elem\_num})^{0.5}`. Default: None. + bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of + deformable conv layer. If it is set to False, no bias will be added + to the output units. If it is set to None or one attribute of ParamAttr, conv2d + will create ParamAttr as bias_attr. If the Initializer of the bias_attr + is not set, the bias is initialized zero. Default: None. + name (str|None): A name for this layer(optional). If set None, the layer + will be named automatically. Default: None + Returns: + Variable: The tensor variable storing the deformable convolution \ + result. + Raises: + ValueError: If the shapes of input, filter_size, stride, padding and + groups mismatch. + Examples: + .. code-block:: python + + data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32') + offset = fluid.layers.data(name='offset', shape=[18, 32, 32], dtype='float32') + mask = fluid.layers.data(name='mask', shape=[9, 32, 32], dtype='float32') + out = fluid.layers.deformable_conv(input=data, offset=offset, mask=mask, + num_filters=2, filter_size=3, padding=1) + """ + + num_channels = input.shape[1] + assert param_attr is not False, "param_attr should not be False here." + + helper = LayerHelper('deformable_conv', **locals()) + dtype = helper.input_dtype() + + if not isinstance(input, Variable): + raise TypeError("Input of deformable_conv must be Variable") + if not isinstance(offset, Variable): + raise TypeError("Input Offset of deformable_conv must be Variable") + if not isinstance(mask, Variable): + raise TypeError("Input Mask of deformable_conv must be Variable") + + if groups is None: + num_filter_channels = num_channels + else: + if num_channels % groups != 0: + raise ValueError("num_channels must be divisible by groups.") + num_filter_channels = num_channels // groups + + filter_size = utils.convert_to_list(filter_size, 2, 'filter_size') + stride = utils.convert_to_list(stride, 2, 'stride') + padding = utils.convert_to_list(padding, 2, 'padding') + dilation = utils.convert_to_list(dilation, 2, 'dilation') + + input_shape = input.shape + filter_shape = [num_filters, int(num_filter_channels)] + filter_size + + def _get_default_param_initializer(): + filter_elem_num = filter_size[0] * filter_size[1] * num_channels + std = (2.0 / filter_elem_num)**0.5 + return Normal(0.0, std, 0) + + filter_param = helper.create_parameter( + attr=helper.param_attr, + shape=filter_shape, + dtype=dtype, + default_initializer=_get_default_param_initializer()) + + pre_bias = helper.create_variable_for_type_inference(dtype) + + helper.append_op( + type='deformable_conv', + inputs={ + 'Input': input, + 'Filter': filter_param, + 'Offset': offset, + 'Mask': mask, + }, + outputs={"Output": pre_bias}, + attrs={ + 'strides': stride, + 'paddings': padding, + 'dilations': dilation, + 'groups': groups, + 'deformable_groups': deformable_groups, + 'im2col_step': im2col_step, + }) + + output = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2) + return output diff --git a/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py b/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py new file mode 100644 index 0000000000000000000000000000000000000000..aacb9ff447ef703379485e742cd231e37ba800c3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py @@ -0,0 +1,294 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np + +import paddle.fluid.core as core +from op_test import OpTest + + +def dmc_bilinear(data_im, height, width, h, w): + h_low = int(np.floor(h)) + w_low = int(np.floor(w)) + h_high = h_low + 1 + w_high = w_low + 1 + + lh = h - h_low + lw = w - w_low + hh = 1 - lh + hw = 1 - lw + + v1 = 0 + if h_low >= 0 and w_low >= 0: + v1 = data_im[h_low, w_low] + v2 = 0 + if h_low >= 0 and w_high <= width - 1: + v2 = data_im[h_low, w_high] + v3 = 0 + if h_high <= height - 1 and w_low >= 0: + v3 = data_im[h_high, w_low] + v4 = 0 + if h_high <= height - 1 and w_high <= width - 1: + v4 = data_im[h_high, w_high] + + w1, w2, w3, w4 = hh * hw, hh * lw, lh * hw, lh * lw + val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4 + + return val + + +def dconv_im2col_gemm(input, offset, mask, filter, group, conv_param): + in_n, in_c, in_h, in_w = input.shape + out_c, f_c, f_h, f_w = filter.shape + + assert offset.shape == (in_n, 2 * f_h * f_w, in_h, in_w) + assert mask.shape == (in_n, f_h * f_w, in_h, in_w) + assert f_c * group == in_c + assert np.mod(out_c, group) == 0 + + stride, pad, dilation = conv_param['stride'], conv_param['pad'],\ + conv_param['dilation'] + out_h = 1 + (in_h + 2 * pad[0] - (dilation[0] * (f_h - 1) + 1)) // stride[0] + out_w = 1 + (in_w + 2 * pad[1] - (dilation[1] * (f_w - 1) + 1)) // stride[1] + assert out_h == in_h + assert out_w == in_w + + col_buffer = np.zeros((in_n, in_c * f_h * f_w, in_h * in_w)) + for n in range(in_n): + for c in range(in_c): + for h in range(out_h): + for w in range(out_w): + for kh in range(f_h): + for kw in range(f_w): + offset_h_table = \ + offset[n, ::2, h, w].reshape(f_h, f_w) + offset_w_table = \ + offset[n, 1::2, h, w].reshape(f_h, f_w) + mask_table = \ + mask[n, :, h, w].reshape(f_h, f_w) + offset_h = offset_h_table[kh, kw] + offset_w = offset_w_table[kh, kw] + val = 0 + im_h = h * stride[0] + kh * dilation[0] \ + + offset_h - pad[0] + im_w = w * stride[0] + kw * dilation[0] \ + + offset_w - pad[1] + if im_h > -1 and im_w > -1 and \ + im_h < in_h and im_w < in_h: + val = dmc_bilinear(input[n, c], in_h, in_w, + im_h, im_w) + val_out = val * mask_table[kh, kw] + col_buffer[n, c * f_h * f_w + kh * f_w + kw, h * + in_w + w] = val_out + + out = np.zeros((in_n, group, int(out_c // group), out_h * out_w)) + weight = filter.reshape(group, int(out_c // group), f_c * f_h * f_w) + col_buffer = col_buffer.reshape( + (in_n, group, int(in_c // group * f_h * f_w), in_h * in_w)) + for n in range(in_n): + for g in range(group): + out[n, g] = np.matmul(weight[g], col_buffer[n, g]) + out = out.reshape(in_n, out_c, out_h, out_w) + return out + + +class TestModulatedDeformableConvOp(OpTest): + def setUp(self): + self.op_type = "deformable_conv" + self.dtype = np.float32 + self.init_group() + self.init_dilation() + self.init_test_case() + + conv_param = { + 'stride': self.stride, + 'pad': self.pad, + 'dilation': self.dilations + } + + input = np.random.random(self.input_size).astype(self.dtype) + offset = 10 * np.random.random(self.offset_size).astype(self.dtype) + mask = 10 * np.random.random(self.mask_size).astype(self.dtype) + filter = np.random.random(self.filter_size).astype(self.dtype) + + output = dconv_im2col_gemm(input, offset, mask, filter, self.groups, + conv_param) + output = output.astype(self.dtype) + + self.inputs = { + 'Input': OpTest.np_dtype_to_fluid_dtype(input), + 'Offset': OpTest.np_dtype_to_fluid_dtype(offset), + 'Mask': OpTest.np_dtype_to_fluid_dtype(mask), + 'Filter': OpTest.np_dtype_to_fluid_dtype(filter) + } + self.attrs = { + 'strides': self.stride, + 'paddings': self.pad, + 'groups': self.groups, + 'deformable_groups': self.deformable_groups, + 'im2col_step': self.im2col_step, + 'dilations': self.dilations, + } + self.outputs = {'Output': output} + + def has_cuda(self): + return core.is_compiled_with_cuda() + + def test_check_output(self): + if self.has_cuda(): + place = core.CUDAPlace(0) + self.check_output_with_place(place, atol=1e-5) + + def test_check_grad(self): + if self.has_cuda(): + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, {'Input', 'Offset', 'Mask', 'Filter'}, + 'Output', + max_relative_error=0.05) + + def test_check_grad_no_filter(self): + if self.has_cuda(): + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, ['Input', 'Offset', 'Mask'], + 'Output', + max_relative_error=0.1, + no_grad_set=set(['Filter'])) + + def test_check_grad_no_input(self): + if self.has_cuda(): + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, ['Filter', 'Offset', 'Mask'], + 'Output', + max_relative_error=0.1, + no_grad_set=set(['Input'])) + + def test_check_grad_no_offset_no_mask(self): + if self.has_cuda(): + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, ['Input', 'Filter'], + 'Output', + max_relative_error=0.1, + no_grad_set=set(['Offset', 'Mask'])) + + def init_test_case(self): + self.pad = [1, 1] + self.stride = [1, 1] + self.dilations = [1, 1] + self.input_size = [2, 4, 4, 4] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [4, f_c, 3, 3] + self.im2col_step = 1 + self.deformable_groups = 1 + offset_c = 2 * self.deformable_groups * self.filter_size[ + 2] * self.filter_size[3] + mask_c = self.deformable_groups * self.filter_size[ + 2] * self.filter_size[3] + self.offset_size = [ + self.input_size[0], offset_c, self.input_size[2], self.input_size[3] + ] + self.mask_size = [ + self.input_size[0], mask_c, self.input_size[2], self.input_size[3] + ] + + def init_dilation(self): + self.dilations = [1, 1] + + def init_group(self): + self.groups = 1 + + +class TestWithStride(TestModulatedDeformableConvOp): + def init_test_case(self): + self.pad = [3, 3] + self.stride = [2, 2] + self.input_size = [2, 3, 5, 5] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3] + self.im2col_step = 1 + self.deformable_groups = 1 + offset_c = 2 * self.deformable_groups * self.filter_size[ + 2] * self.filter_size[3] + mask_c = self.deformable_groups * self.filter_size[ + 2] * self.filter_size[3] + self.offset_size = [ + self.input_size[0], offset_c, self.input_size[2], self.input_size[3] + ] + self.mask_size = [ + self.input_size[0], mask_c, self.input_size[2], self.input_size[3] + ] + + +class TestWithDilation(TestModulatedDeformableConvOp): + def init_test_case(self): + self.pad = [2, 2] + self.stride = [1, 1] + self.input_size = [2, 3, 4, 4] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3] + self.im2col_step = 1 + self.deformable_groups = 1 + offset_c = 2 * self.deformable_groups * self.filter_size[ + 2] * self.filter_size[3] + mask_c = self.deformable_groups * self.filter_size[ + 2] * self.filter_size[3] + self.offset_size = [ + self.input_size[0], offset_c, self.input_size[2], self.input_size[3] + ] + self.mask_size = [ + self.input_size[0], mask_c, self.input_size[2], self.input_size[3] + ] + + def init_dilation(self): + self.dilations = [2, 2] + + +class TestWith1x1(TestModulatedDeformableConvOp): + def init_test_case(self): + self.pad = [0, 0] + self.stride = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 1, 1] + self.im2col_step = 1 + self.deformable_groups = 1 + offset_c = 2 * self.deformable_groups * self.filter_size[ + 2] * self.filter_size[3] + mask_c = self.deformable_groups * self.filter_size[ + 2] * self.filter_size[3] + self.offset_size = [ + self.input_size[0], offset_c, self.input_size[2], self.input_size[3] + ] + self.mask_size = [ + self.input_size[0], mask_c, self.input_size[2], self.input_size[3] + ] + + +class TestWithGroup(TestModulatedDeformableConvOp): + def init_group(self): + self.groups = 2 + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 2474125835fbf54316e26d272eec940fc380a448..4136bb7fef1054ee5698e05bb297e25e20317cf4 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -1957,6 +1957,34 @@ class TestBook(LayerTest): self.assertIsNotNone(out) print(str(program)) + def test_deformable_conv(self): + if core.is_compiled_with_cuda(): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = layers.data( + name='input', + append_batch_size=False, + shape=[2, 3, 32, 32], + dtype="float32") + offset = layers.data( + name='offset', + append_batch_size=False, + shape=[2, 18, 32, 32], + dtype="float32") + mask = layers.data( + name='mask', + append_batch_size=False, + shape=[2, 9, 32, 32], + dtype="float32") + out = layers.deformable_conv( + input=input, + offset=offset, + mask=mask, + num_filters=2, + filter_size=3, + padding=1) + return (out) + if __name__ == '__main__': unittest.main()