diff --git a/paddle/fluid/operators/deformable_conv_op_mlu.cc b/paddle/fluid/operators/deformable_conv_op_mlu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a280d7ef1ad1f602f34a2b62af3c4408651fadad
--- /dev/null
+++ b/paddle/fluid/operators/deformable_conv_op_mlu.cc
@@ -0,0 +1,248 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/mlu/mlu_baseop.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+class DeformableConvMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<Tensor>("Input");
+    auto* offset = ctx.Input<Tensor>("Offset");
+    auto* mask = ctx.Input<Tensor>("Mask");
+    auto* filter = ctx.Input<Tensor>("Filter");
+    auto* output = ctx.Output<Tensor>("Output");
+    output->mutable_data<T>(ctx.GetPlace());
+
+    const int groups = ctx.Attr<int>("groups");
+    const int deformable_groups = ctx.Attr<int>("deformable_groups");
+    const int im2col_step = ctx.Attr<int>("im2col_step");
+    const std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
+    const std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
+    const std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
+
+    // TODO(fwg): Remove this check once cnnl fixes its bug with groups > 1.
+    PADDLE_ENFORCE_EQ(
+        groups == 1, true,
+        platform::errors::InvalidArgument(
+            "MLU deformable_conv kernel only supports groups == 1, but got "
+            "%d.",
+            groups));
+
+    // transform paddings from {h, w} to {top, bottom, left, right}.
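+    // The DCN descriptor takes one explicit pad value per side, so the
+    // symmetric {h, w} attribute is duplicated into four entries below.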
+    const std::vector<int> trans_paddings{paddings[0], paddings[0],
+                                          paddings[1], paddings[1]};
+    MLUCnnlDCNDesc dcn_desc(input->dims().size(), trans_paddings.data(),
+                            strides.data(), dilations.data(),
+                            deformable_groups, groups, im2col_step);
+
+    const std::vector<int> perm_to_nhwc = {0, 2, 3, 1};
+    Tensor trans_input(input->dtype());
+    TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, input, &trans_input,
+                              true /*need_reshape_or_alloc*/);
+
+    Tensor trans_offset(offset->dtype());
+    TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, offset, &trans_offset,
+                              true /*need_reshape_or_alloc*/);
+
+    Tensor trans_mask(mask->dtype());
+    TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, mask, &trans_mask,
+                              true /*need_reshape_or_alloc*/);
+
+    Tensor trans_filter(filter->dtype());
+    TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, filter, &trans_filter,
+                              true /*need_reshape_or_alloc*/);
+
+    Tensor tmp_output(output->dtype());
+    auto output_dims = output->dims();
+    tmp_output.mutable_data<T>(
+        {output_dims[0], output_dims[2], output_dims[3], output_dims[1]},
+        ctx.GetPlace());
+
+    cnnlTensorLayout_t data_layout = CNNL_LAYOUT_NHWC;
+    MLUCnnlTensorDesc input_desc(trans_input, data_layout,
+                                 ToCnnlDataType(trans_input.dtype()));
+    MLUCnnlTensorDesc offset_desc(trans_offset, data_layout,
+                                  ToCnnlDataType(trans_offset.dtype()));
+    MLUCnnlTensorDesc mask_desc(trans_mask, data_layout,
+                                ToCnnlDataType(trans_mask.dtype()));
+    MLUCnnlTensorDesc filter_desc(trans_filter, data_layout,
+                                  ToCnnlDataType(trans_filter.dtype()));
+    MLUCnnlTensorDesc output_desc(tmp_output, data_layout,
+                                  ToCnnlDataType(tmp_output.dtype()));
+    MLUCnnl::DCNForward(
+        ctx, dcn_desc.get(), input_desc.get(), GetBasePtr(&trans_input),
+        offset_desc.get(), GetBasePtr(&trans_offset), mask_desc.get(),
+        GetBasePtr(&trans_mask), filter_desc.get(), GetBasePtr(&trans_filter),
+        nullptr, nullptr, output_desc.get(), GetBasePtr(&tmp_output));
+
+    const std::vector<int> perm_to_nchw = {0, 3, 1, 2};
+    TransposeFromMLUTensor<T>(ctx, perm_to_nchw, &tmp_output, output,
+                              false /*need_reshape_or_alloc*/);
+  }
+};
+
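+// The backward kernel mirrors the forward pass: transpose every NCHW input
+// to NHWC, run the cnnl DCN backward kernels, then transpose each computed
+// gradient back to NCHW.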
+template <typename T>
+class DeformableConvGradMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const Tensor* output_grad =
+        ctx.Input<Tensor>(framework::GradVarName("Output"));
+    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
+    auto* filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));
+    auto* offset_grad = ctx.Output<Tensor>(framework::GradVarName("Offset"));
+    auto* mask_grad = ctx.Output<Tensor>(framework::GradVarName("Mask"));
+
+    const Tensor* input = ctx.Input<Tensor>("Input");
+    auto* offset = ctx.Input<Tensor>("Offset");
+    auto* mask = ctx.Input<Tensor>("Mask");
+    auto* filter = ctx.Input<Tensor>("Filter");
+
+    int groups = ctx.Attr<int>("groups");
+    int deformable_groups = ctx.Attr<int>("deformable_groups");
+    int im2col_step = ctx.Attr<int>("im2col_step");
+    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
+    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
+
+    // TODO(fwg): Remove this check once cnnl fixes its bug with groups > 1.
+    PADDLE_ENFORCE_EQ(groups == 1, true,
+                      platform::errors::InvalidArgument(
+                          "MLU deformable_conv_grad kernel only supports "
+                          "groups == 1, but got %d.",
+                          groups));
+
+    // transform paddings from {h, w} to {top, bottom, left, right}.
+    const std::vector<int> trans_paddings{paddings[0], paddings[0],
+                                          paddings[1], paddings[1]};
+    MLUCnnlDCNDesc dcn_desc(input->dims().size(), trans_paddings.data(),
+                            strides.data(), dilations.data(),
+                            deformable_groups, groups, im2col_step);
+
+    Tensor tmp_input_grad;
+    auto input_dims = input->dims();
+    tmp_input_grad.mutable_data<T>(
+        {input_dims[0], input_dims[2], input_dims[3], input_dims[1]},
+        ctx.GetPlace());
+
+    Tensor tmp_filter_grad;
+    auto filter_dims = filter->dims();
+    tmp_filter_grad.mutable_data<T>(
+        {filter_dims[0], filter_dims[2], filter_dims[3], filter_dims[1]},
+        ctx.GetPlace());
+
+    Tensor tmp_offset_grad;
+    auto offset_dims = offset->dims();
+    tmp_offset_grad.mutable_data<T>(
+        {offset_dims[0], offset_dims[2], offset_dims[3], offset_dims[1]},
+        ctx.GetPlace());
+
+    Tensor tmp_mask_grad;
+    auto mask_dims = mask->dims();
+    tmp_mask_grad.mutable_data<T>(
+        {mask_dims[0], mask_dims[2], mask_dims[3], mask_dims[1]},
+        ctx.GetPlace());
+
+    const std::vector<int> perm_to_nhwc = {0, 2, 3, 1};
+    Tensor trans_output_grad(output_grad->dtype());
+    TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, output_grad,
+                              &trans_output_grad,
+                              true /*need_reshape_or_alloc*/);
+
+    Tensor trans_input(input->dtype());
+    TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, input, &trans_input,
+                              true /*need_reshape_or_alloc*/);
+
+    Tensor trans_offset(offset->dtype());
+    TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, offset, &trans_offset,
+                              true /*need_reshape_or_alloc*/);
+
+    Tensor trans_mask(mask->dtype());
+    TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, mask, &trans_mask,
+                              true /*need_reshape_or_alloc*/);
+
+    Tensor trans_filter(filter->dtype());
+    TransposeFromMLUTensor<T>(ctx, perm_to_nhwc, filter, &trans_filter,
+                              true /*need_reshape_or_alloc*/);
+
+    cnnlTensorLayout_t data_layout = CNNL_LAYOUT_NHWC;
+    MLUCnnlTensorDesc output_grad_desc(
+        trans_output_grad, data_layout,
+        ToCnnlDataType(trans_output_grad.dtype()));
+    MLUCnnlTensorDesc input_desc(trans_input, data_layout,
+                                 ToCnnlDataType(trans_input.dtype()));
+    MLUCnnlTensorDesc offset_desc(trans_offset, data_layout,
+                                  ToCnnlDataType(trans_offset.dtype()));
+    MLUCnnlTensorDesc mask_desc(trans_mask, data_layout,
+                                ToCnnlDataType(trans_mask.dtype()));
+    MLUCnnlTensorDesc filter_desc(trans_filter, data_layout,
+                                  ToCnnlDataType(trans_filter.dtype()));
+
+    MLUCnnl::DCNBackwardData(
+        ctx, dcn_desc.get(), input_desc.get(), GetBasePtr(&trans_input),
+        offset_desc.get(), GetBasePtr(&trans_offset), mask_desc.get(),
+        GetBasePtr(&trans_mask), filter_desc.get(), GetBasePtr(&trans_filter),
+        output_grad_desc.get(), GetBasePtr(&trans_output_grad),
+        input_desc.get(), GetBasePtr(&tmp_input_grad), offset_desc.get(),
+        GetBasePtr(&tmp_offset_grad), mask_desc.get(),
+        GetBasePtr(&tmp_mask_grad));
+
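+    // deformable_conv has no bias input, so the bias-gradient descriptor and
+    // pointer of DCNBackwardWeight are passed as nullptr below.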
+    MLUCnnl::DCNBackwardWeight(
+        ctx, dcn_desc.get(), input_desc.get(), GetBasePtr(&trans_input),
+        offset_desc.get(), GetBasePtr(&trans_offset), mask_desc.get(),
+        GetBasePtr(&trans_mask), output_grad_desc.get(),
+        GetBasePtr(&trans_output_grad), filter_desc.get(),
+        GetBasePtr(&tmp_filter_grad), nullptr, nullptr);
+
+    const std::vector<int> perm_to_nchw = {0, 3, 1, 2};
+    if (input_grad) {
+      input_grad->mutable_data<T>(ctx.GetPlace());
+      TransposeFromMLUTensor<T>(ctx, perm_to_nchw, &tmp_input_grad, input_grad,
+                                false /*need_reshape_or_alloc*/);
+    }
+
+    if (filter_grad) {
+      filter_grad->mutable_data<T>(ctx.GetPlace());
+      TransposeFromMLUTensor<T>(ctx, perm_to_nchw, &tmp_filter_grad,
+                                filter_grad, false /*need_reshape_or_alloc*/);
+    }
+
+    if (offset_grad) {
+      offset_grad->mutable_data<T>(ctx.GetPlace());
+      TransposeFromMLUTensor<T>(ctx, perm_to_nchw, &tmp_offset_grad,
+                                offset_grad, false /*need_reshape_or_alloc*/);
+    }
+
+    if (mask_grad) {
+      mask_grad->mutable_data<T>(ctx.GetPlace());
+      TransposeFromMLUTensor<T>(ctx, perm_to_nchw, &tmp_mask_grad, mask_grad,
+                                false /*need_reshape_or_alloc*/);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_MLU_KERNEL(deformable_conv, ops::DeformableConvMLUKernel<float>);
+REGISTER_OP_MLU_KERNEL(deformable_conv_grad,
+                       ops::DeformableConvGradMLUKernel<float>);
diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc
index 89ea065e21c7f2f8faac54db1871664f4247a9df..445a9fecff74b464cd4be3dce34e39859540a3a1 100644
--- a/paddle/fluid/operators/mlu/mlu_baseop.cc
+++ b/paddle/fluid/operators/mlu/mlu_baseop.cc
@@ -506,6 +506,24 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
   }
 }
 
+MLUCnnlDCNDesc::MLUCnnlDCNDesc(int dimNb, const int* pad, const int* stride,
+                               const int* dilation, int deformable_group,
+                               int conv_group, int im2col_step) {
+  PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreateDCNDescriptor(&dcn_desc_));
+  const cnnlDataType_t compute_type = CNNL_DTYPE_FLOAT;
+  PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetDCNDescriptor(
+      dcn_desc_, dimNb, pad, stride, dilation, deformable_group, conv_group,
+      im2col_step, compute_type));
+}
+
+const cnnlDCNDescriptor_t MLUCnnlDCNDesc::get() const { return dcn_desc_; }
+
+MLUCnnlDCNDesc::~MLUCnnlDCNDesc() {
+  if (dcn_desc_) {
+    PADDLE_ENFORCE_MLU_SUCCESS(cnnlDestroyDCNDescriptor(dcn_desc_));
+  }
+}
+
 /* static */ void MLUCnnl::Active(const ExecutionContext& ctx,
                                   cnnlActivationDescriptor_t active_desc,
                                   const cnnlTensorDescriptor_t input_desc,
@@ -2488,6 +2506,88 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
       workspace_size, nullptr /*beta*/, filter_backprop_desc, filter_backprop));
 }
 
+/* static */ void MLUCnnl::DCNForward(
+    const ExecutionContext& ctx, const cnnlDCNDescriptor_t dcn_desc,
+    const cnnlTensorDescriptor_t input_desc, const void* input,
+    const cnnlTensorDescriptor_t offset_desc, const void* offset,
+    const cnnlTensorDescriptor_t mask_desc, const void* mask,
+    const cnnlTensorDescriptor_t weight_desc, const void* weight,
+    const cnnlTensorDescriptor_t bias_desc, const void* bias,
+    const cnnlTensorDescriptor_t output_desc, void* output) {
+  cnnlHandle_t handle = GetHandleFromCTX(ctx);
+
+  size_t workspace_size = 0;
+  PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetDCNForwardWorkspaceSize(
+      handle, dcn_desc, input_desc, offset_desc, mask_desc, weight_desc,
+      bias_desc, output_desc, &workspace_size));
+
+  auto& dev_ctx = GetDevCtxFromCTX(ctx);
+  Tensor workspace = ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
+      {static_cast<int64_t>(workspace_size)}, dev_ctx);
+  void* workspace_ptr = workspace.mutable_data(ctx.GetPlace());
+
+  PADDLE_ENFORCE_MLU_SUCCESS(
+      cnnlDCNForward(handle, dcn_desc, input_desc, input, offset_desc, offset,
+                     mask_desc, mask, weight_desc, weight, bias_desc, bias,
+                     workspace_ptr, workspace_size, output_desc, output));
+}
+
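+// As in DCNForward above, each backward wrapper below first queries the
+// workspace size cnnl requires, then allocates a temporary byte tensor of
+// that size for the kernel to use as scratch space.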
+/* static */ void MLUCnnl::DCNBackwardData(
+    const ExecutionContext& ctx, const cnnlDCNDescriptor_t dcn_desc,
+    const cnnlTensorDescriptor_t input_desc, const void* input,
+    const cnnlTensorDescriptor_t offset_desc, const void* offset,
+    const cnnlTensorDescriptor_t mask_desc, const void* mask,
+    const cnnlTensorDescriptor_t weight_desc, const void* weight,
+    const cnnlTensorDescriptor_t grad_output_desc, const void* grad_output,
+    const cnnlTensorDescriptor_t grad_input_desc, void* grad_input,
+    const cnnlTensorDescriptor_t grad_offset_desc, void* grad_offset,
+    const cnnlTensorDescriptor_t grad_mask_desc, void* grad_mask) {
+  cnnlHandle_t handle = GetHandleFromCTX(ctx);
+
+  size_t workspace_size = 0;
+  PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetDCNBakcwardDataWorkspaceSize(
+      handle, dcn_desc, input_desc, offset_desc, mask_desc, weight_desc,
+      grad_output_desc, grad_input_desc, grad_offset_desc, grad_mask_desc,
+      &workspace_size));
+
+  auto& dev_ctx = GetDevCtxFromCTX(ctx);
+  Tensor workspace = ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
+      {static_cast<int64_t>(workspace_size)}, dev_ctx);
+  void* workspace_ptr = workspace.mutable_data(ctx.GetPlace());
+
+  PADDLE_ENFORCE_MLU_SUCCESS(cnnlDCNBackwardData(
+      handle, dcn_desc, input_desc, input, offset_desc, offset, mask_desc,
+      mask, weight_desc, weight, grad_output_desc, grad_output, workspace_ptr,
+      workspace_size, grad_input_desc, grad_input, grad_offset_desc,
+      grad_offset, grad_mask_desc, grad_mask));
+}
+
+/* static */ void MLUCnnl::DCNBackwardWeight(
+    const ExecutionContext& ctx, const cnnlDCNDescriptor_t dcn_desc,
+    const cnnlTensorDescriptor_t input_desc, const void* input,
+    const cnnlTensorDescriptor_t offset_desc, const void* offset,
+    const cnnlTensorDescriptor_t mask_desc, const void* mask,
+    const cnnlTensorDescriptor_t grad_output_desc, const void* grad_output,
+    const cnnlTensorDescriptor_t grad_weight_desc, void* grad_weight,
+    const cnnlTensorDescriptor_t grad_bias_desc, void* grad_bias) {
+  cnnlHandle_t handle = GetHandleFromCTX(ctx);
+
+  size_t workspace_size = 0;
+  PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetDCNBackwardWeightWorkspaceSize(
+      handle, dcn_desc, input_desc, offset_desc, mask_desc, grad_output_desc,
+      grad_weight_desc, grad_bias_desc, &workspace_size));
+
+  auto& dev_ctx = GetDevCtxFromCTX(ctx);
+  Tensor workspace = ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
+      {static_cast<int64_t>(workspace_size)}, dev_ctx);
+  void* workspace_ptr = workspace.mutable_data(ctx.GetPlace());
+
+  PADDLE_ENFORCE_MLU_SUCCESS(cnnlDCNBackwardWeight(
+      handle, dcn_desc, input_desc, input, offset_desc, offset, mask_desc,
+      mask, grad_output_desc, grad_output, workspace_ptr, workspace_size,
+      grad_weight_desc, grad_weight, grad_bias_desc, grad_bias));
+}
+
 /* static */ void MLUCnnl::QuantizeMatMul(
     const ExecutionContext& ctx, const bool transpose_a, const bool transpose_b,
     const cnnlTensorDescriptor_t a_desc, const void* a, const void* a_position,
diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h
index b8fa0cbdd4a8ccba8af206f680ff9a48008f4596..4ba7ae5ac6e9ef041673b62cca538502c7ebc209 100644
--- a/paddle/fluid/operators/mlu/mlu_baseop.h
+++ b/paddle/fluid/operators/mlu/mlu_baseop.h
@@ -444,6 +444,19 @@ class MLUCnnlTrigonDesc {
   cnnlTrigonDescriptor_t trigon_desc_ = nullptr;
 };
 
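+// RAII wrapper for cnnlDCNDescriptor_t: the descriptor is created and
+// configured in the constructor and destroyed in the destructor.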
+class MLUCnnlDCNDesc {
+ public:
+  MLUCnnlDCNDesc(int dimNb, const int* pad, const int* stride,
+                 const int* dilation, int deformable_group, int conv_group,
+                 int im2col_step);
+  const cnnlDCNDescriptor_t get() const;
+
+  ~MLUCnnlDCNDesc();
+
+ private:
+  cnnlDCNDescriptor_t dcn_desc_ = nullptr;
+};
+
 class MLUCnnl {
  public:
   static void Active(const ExecutionContext& ctx,
@@ -1233,6 +1246,35 @@ class MLUCnnl {
       const cnnlTensorDescriptor_t out_backprop_desc, const void* out_backprop,
       const cnnlTensorDescriptor_t filter_backprop_desc, void* filter_backprop);
 
+  static void DCNForward(
+      const ExecutionContext& ctx, const cnnlDCNDescriptor_t dcn_desc,
+      const cnnlTensorDescriptor_t input_desc, const void* input,
+      const cnnlTensorDescriptor_t offset_desc, const void* offset,
+      const cnnlTensorDescriptor_t mask_desc, const void* mask,
+      const cnnlTensorDescriptor_t weight_desc, const void* weight,
+      const cnnlTensorDescriptor_t bias_desc, const void* bias,
+      const cnnlTensorDescriptor_t output_desc, void* output);
+
+  static void DCNBackwardData(
+      const ExecutionContext& ctx, const cnnlDCNDescriptor_t dcn_desc,
+      const cnnlTensorDescriptor_t input_desc, const void* input,
+      const cnnlTensorDescriptor_t offset_desc, const void* offset,
+      const cnnlTensorDescriptor_t mask_desc, const void* mask,
+      const cnnlTensorDescriptor_t weight_desc, const void* weight,
+      const cnnlTensorDescriptor_t grad_output_desc, const void* grad_output,
+      const cnnlTensorDescriptor_t grad_input_desc, void* grad_input,
+      const cnnlTensorDescriptor_t grad_offset_desc, void* grad_offset,
+      const cnnlTensorDescriptor_t grad_mask_desc, void* grad_mask);
+
+  static void DCNBackwardWeight(
+      const ExecutionContext& ctx, const cnnlDCNDescriptor_t dcn_desc,
+      const cnnlTensorDescriptor_t input_desc, const void* input,
+      const cnnlTensorDescriptor_t offset_desc, const void* offset,
+      const cnnlTensorDescriptor_t mask_desc, const void* mask,
+      const cnnlTensorDescriptor_t grad_output_desc, const void* grad_output,
+      const cnnlTensorDescriptor_t grad_weight_desc, void* grad_weight,
+      const cnnlTensorDescriptor_t grad_bias_desc, void* grad_bias);
+
   static void InTopK(const ExecutionContext& ctx,
                      const cnnlTensorDescriptor_t predictions_desc,
                      const void* predictions,
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_deformable_conv_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_deformable_conv_op_mlu.py
new file mode 100644
index 0000000000000000000000000000000000000000..65cf070bf22660ede579c063945d409c7e7b3079
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/test_deformable_conv_op_mlu.py
@@ -0,0 +1,189 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import paddle
+import unittest
+import numpy as np
+import paddle.fluid.core as core
+import paddle.fluid as fluid
+
+import sys
+
+sys.path.append('..')
+from op_test import OpTest
+from test_deformable_conv_op import dconv_im2col_gemm, deform_conv2d_wrapper
+
+paddle.enable_static()
+
+
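+# The tests below reuse the CPU reference implementation dconv_im2col_gemm
+# from test_deformable_conv_op to generate expected outputs, so the MLU
+# kernel is checked against the same ground truth as the CPU operator.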
+class TestModulatedDeformableConvOp(OpTest):
+
+    def setUp(self):
+        self.place = paddle.device.MLUPlace(0)
+        self.__class__.use_mlu = True
+        self.python_api = deform_conv2d_wrapper
+        self.op_type = "deformable_conv"
+        self.init_type()
+        self.init_group()
+        self.init_dilation()
+        self.init_test_case()
+
+        conv_param = {
+            'stride': self.stride,
+            'pad': self.pad,
+            'dilation': self.dilations
+        }
+
+        input = np.random.random(self.input_size).astype(self.dtype)
+        offset = 10 * np.random.random(self.offset_size).astype(self.dtype)
+        mask = 10 * np.random.random(self.mask_size).astype(self.dtype)
+        filter = np.random.random(self.filter_size).astype(self.dtype)
+
+        output = dconv_im2col_gemm(input, offset, mask, filter, self.groups,
+                                   conv_param)
+        output = output.astype(self.dtype)
+
+        self.inputs = {
+            'Input': OpTest.np_dtype_to_fluid_dtype(input),
+            'Offset': OpTest.np_dtype_to_fluid_dtype(offset),
+            'Mask': OpTest.np_dtype_to_fluid_dtype(mask),
+            'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
+        }
+        self.attrs = {
+            'strides': self.stride,
+            'paddings': self.pad,
+            'groups': self.groups,
+            'deformable_groups': self.deformable_groups,
+            'im2col_step': self.im2col_step,
+            'dilations': self.dilations,
+        }
+        self.outputs = {'Output': output}
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, check_eager=False)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(self.place,
+                                   {'Input', 'Offset', 'Mask', 'Filter'},
+                                   'Output',
+                                   max_relative_error=0.05,
+                                   check_eager=False)
+
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [1, 1]
+        self.dilations = [1, 1]
+        self.input_size = [2, 8, 4, 4]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [4, f_c, 3, 3]
+        self.im2col_step = 1
+        self.deformable_groups = 1
+        offset_c = 2 * self.deformable_groups * self.filter_size[
+            2] * self.filter_size[3]
+        mask_c = self.deformable_groups * self.filter_size[
+            2] * self.filter_size[3]
+        self.offset_size = [
+            self.input_size[0], offset_c, self.input_size[2],
+            self.input_size[3]
+        ]
+        self.mask_size = [
+            self.input_size[0], mask_c, self.input_size[2], self.input_size[3]
+        ]
+
+    def init_dilation(self):
+        self.dilations = [1, 1]
+
+    def init_group(self):
+        self.groups = 1
+
+    def init_type(self):
+        self.dtype = np.float32
+
+
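+# The variant cases below only override init_test_case (and init_dilation)
+# to exercise strided, dilated, and differently shaped configurations.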
+class TestWithStride(TestModulatedDeformableConvOp):
+
+    def init_test_case(self):
+        self.pad = [3, 3]
+        self.stride = [2, 2]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [6, f_c, 3, 3]
+        self.im2col_step = 1
+        self.deformable_groups = 1
+        offset_c = 2 * self.deformable_groups * self.filter_size[
+            2] * self.filter_size[3]
+        mask_c = self.deformable_groups * self.filter_size[
+            2] * self.filter_size[3]
+        self.offset_size = [
+            self.input_size[0], offset_c, self.input_size[2],
+            self.input_size[3]
+        ]
+        self.mask_size = [
+            self.input_size[0], mask_c, self.input_size[2], self.input_size[3]
+        ]
+
+
+class TestWithDilation(TestModulatedDeformableConvOp):
+
+    def init_test_case(self):
+        self.pad = [2, 2]
+        self.stride = [1, 1]
+        self.input_size = [4, 3, 4, 4]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [6, f_c, 3, 3]
+        self.im2col_step = 1
+        self.deformable_groups = 1
+        offset_c = 2 * self.deformable_groups * self.filter_size[
+            2] * self.filter_size[3]
+        mask_c = self.deformable_groups * self.filter_size[
+            2] * self.filter_size[3]
+        self.offset_size = [
+            self.input_size[0], offset_c, self.input_size[2],
+            self.input_size[3]
+        ]
+        self.mask_size = [
+            self.input_size[0], mask_c, self.input_size[2], self.input_size[3]
+        ]
+
+    def init_dilation(self):
+        self.dilations = [2, 2]
+
+
+class TestWith3x3(TestModulatedDeformableConvOp):
+
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [6, f_c, 3, 3]
+        self.im2col_step = 1
+        self.deformable_groups = 1
+        offset_c = 2 * self.deformable_groups * self.filter_size[
+            2] * self.filter_size[3]
+        mask_c = self.deformable_groups * self.filter_size[
+            2] * self.filter_size[3]
+        self.offset_size = [
+            self.input_size[0], offset_c, self.input_size[2],
+            self.input_size[3]
+        ]
+        self.mask_size = [
+            self.input_size[0], mask_c, self.input_size[2], self.input_size[3]
+        ]
+
+
+if __name__ == '__main__':
+    unittest.main()