diff --git a/paddle/fluid/operators/deformable_conv_op_xpu.cc b/paddle/fluid/operators/deformable_conv_op_xpu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..8dc5e59ee95716b90dc82264383a5e557c7bbe39
--- /dev/null
+++ b/paddle/fluid/operators/deformable_conv_op_xpu.cc
@@ -0,0 +1,272 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef PADDLE_WITH_XPU
+#include <algorithm>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/xpu_header.h"
+#include "xpu/refactor/math.h"
+#include "xpu/refactor/nn.h"
+
+// NOTE(review): this file was recovered from a collapsed diff in which the
+// contents of every `<...>` (standard includes, template parameter lists and
+// explicit template arguments) had been stripped. They are restored below
+// from how the values are used; confirm them against the original commit.
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+// Rejects attribute/shape combinations that the XPU kernels in this file do
+// not accept: groups and deformable_groups must both be 1, and the filter
+// height and width must each be at most 8.
+static void CheckXPUDeformableConvSupport(int groups, int deformable_groups,
+                                          const framework::DDim& filter_dims) {
+  PADDLE_ENFORCE_EQ(
+      deformable_groups, 1,
+      platform::errors::InvalidArgument(
+          "XPU only support deformable_groups == 1 in deformable_conv op."));
+  PADDLE_ENFORCE_EQ(
+      groups, 1,
+      platform::errors::InvalidArgument(
+          "XPU only support groups == 1 in deformable_conv op."));
+  // The bound is inclusive, so the message must not say "less than 8".
+  PADDLE_ENFORCE_EQ(filter_dims[2] <= 8 && filter_dims[3] <= 8, true,
+                    platform::errors::InvalidArgument(
+                        "Filter height and width should be no greater than 8 "
+                        "on xpu in deformable_conv op."));
+}
+
+// Forward XPU kernel of deformable_conv. The batch is split into slices of
+// `im2col_step` samples and xpu::deformable_conv is called once per slice.
+template <typename DeviceContext, typename T>
+class DeformableConvXPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<Tensor>("Input");
+    auto* offset = ctx.Input<Tensor>("Offset");
+    auto* mask = ctx.Input<Tensor>("Mask");
+    Tensor filter = *ctx.Input<Tensor>("Filter");
+    Tensor* output = ctx.Output<Tensor>("Output");
+    output->mutable_data<T>(ctx.GetPlace());
+
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+
+    const int groups = ctx.Attr<int>("groups");
+    const int deformable_groups = ctx.Attr<int>("deformable_groups");
+    const int im2col_step = ctx.Attr<int>("im2col_step");
+    const std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
+    const std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
+    const std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
+
+    CheckXPUDeformableConvSupport(groups, deformable_groups, filter.dims());
+
+    const int batch_size = static_cast<int>(input->dims()[0]);
+    const auto output_shape_vec = framework::vectorize(output->dims());
+
+    const T* input_ptr = input->data<T>();
+    const T* filter_ptr = filter.data<T>();
+    const float* offset_ptr = offset->data<float>();
+    const float* mask_ptr = mask->data<float>();
+    T* output_ptr = output->data<T>();
+
+    // Zero the whole output before the per-slice kernel calls below.
+    int r = xpu::constant<T>(dev_ctx.x_context(), output_ptr, output->numel(),
+                             static_cast<T>(0));
+    PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
+                      platform::errors::External(
+                          "XPU API return wrong value[%d], please check where "
+                          "Baidu Kunlun Card is properly installed.",
+                          r));
+
+    // Element counts of a single sample in each tensor, used to advance the
+    // raw pointers from one slice to the next.
+    const int input_dim = input->numel() / input->dims()[0];
+    const int input_offset_dim = offset->numel() / offset->dims()[0];
+    const int input_mask_dim = mask->numel() / mask->dims()[0];
+    const int output_dim =
+        output_shape_vec[1] * output_shape_vec[2] * output_shape_vec[3];
+    std::vector<int> ksize{static_cast<int>(filter.dims()[2]),
+                           static_cast<int>(filter.dims()[3])};
+    const int n = im2col_step;
+    const int c = input->dims()[1];
+    const int h = input->dims()[2];
+    const int w = input->dims()[3];
+    const int f = filter.dims()[0];
+
+    for (int i = 0; i < batch_size / im2col_step; ++i) {
+      const int first = i * im2col_step;  // first sample of this slice
+      r = xpu::deformable_conv<float, float, float, int>(
+          dev_ctx.x_context(), input_ptr + first * input_dim, filter_ptr,
+          offset_ptr + first * input_offset_dim,
+          mask_ptr + first * input_mask_dim, output_ptr + first * output_dim,
+          n, c, h, w, f, ksize, strides, paddings, dilations, groups,
+          deformable_groups, nullptr, nullptr, nullptr, true);
+      PADDLE_ENFORCE_EQ(
+          r, XPU_SUCCESS,
+          platform::errors::External(
+              "XPU deformable_conv kernel return wrong value[%d].", r));
+    }
+  }
+};
+
+// Backward XPU kernel of deformable_conv. It fills the gradients of Input,
+// Filter, Offset and Mask; the filter gradient of each slice is produced in
+// a scratch buffer and then added into the accumulated filter gradient.
+template <typename DeviceContext, typename T>
+class DeformableConvGradXPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const Tensor* output_grad =
+        ctx.Input<Tensor>(framework::GradVarName("Output"));
+    Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
+    Tensor* filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));
+    Tensor* offset_grad = ctx.Output<Tensor>(framework::GradVarName("Offset"));
+    Tensor* mask_grad = ctx.Output<Tensor>(framework::GradVarName("Mask"));
+
+    const Tensor* input = ctx.Input<Tensor>("Input");
+    Tensor offset = *ctx.Input<Tensor>("Offset");
+    Tensor mask = *ctx.Input<Tensor>("Mask");
+    Tensor filter = *ctx.Input<Tensor>("Filter");
+
+    const int groups = ctx.Attr<int>("groups");
+    const int deformable_groups = ctx.Attr<int>("deformable_groups");
+    const int im2col_step = ctx.Attr<int>("im2col_step");
+    const std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
+    const std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
+    const std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
+
+    CheckXPUDeformableConvSupport(groups, deformable_groups, filter.dims());
+
+    // A requested gradient is written straight into its output tensor. A
+    // gradient that was not requested (null tensor) still needs a destination
+    // because all four pointers are handed to the kernel, so it gets a
+    // scratch device buffer that is released at the end of Compute.
+    // NOTE(review): scratch buffers leak if a later PADDLE_ENFORCE throws.
+    auto bind_or_alloc = [&ctx](Tensor* grad, int64_t numel) -> T* {
+      if (grad != nullptr) {
+        grad->mutable_data<T>(ctx.GetPlace());
+        return grad->data<T>();
+      }
+      T* scratch = nullptr;
+      PADDLE_ENFORCE_EQ(
+          xpu_malloc(reinterpret_cast<void**>(&scratch), numel * sizeof(T)),
+          XPU_SUCCESS,
+          platform::errors::ResourceExhausted("XPU has no enough memory"));
+      return scratch;
+    };
+    T* dx_data = bind_or_alloc(input_grad, input->numel());
+    T* dw_data = bind_or_alloc(filter_grad, filter.numel());
+    T* doffset_data = bind_or_alloc(offset_grad, offset.numel());
+    T* dmask_data = bind_or_alloc(mask_grad, mask.numel());
+    // Sized from `filter`, not `filter_grad`, which is null when the filter
+    // gradient is not requested.
+    T* filter_grad_tmp = bind_or_alloc(nullptr, filter.numel());
+
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    const int batch_size = static_cast<int>(input->dims()[0]);
+    const auto output_shape_vec = framework::vectorize(output_grad->dims());
+    const T* output_grad_ptr = output_grad->data<T>();
+    const T* input_ptr = input->data<T>();
+    const T* filter_ptr = filter.data<T>();
+    const float* offset_ptr = offset.data<float>();
+    const float* mask_ptr = mask.data<float>();
+
+    // Per-sample element counts used to step the raw pointers slice by slice.
+    const int input_dim = input->numel() / input->dims()[0];
+    const int input_offset_dim = offset.numel() / offset.dims()[0];
+    const int input_mask_dim = mask.numel() / mask.dims()[0];
+    const int output_dim =
+        output_shape_vec[1] * output_shape_vec[2] * output_shape_vec[3];
+    std::vector<int> ksize{static_cast<int>(filter.dims()[2]),
+                           static_cast<int>(filter.dims()[3])};
+    const int n = im2col_step;
+    const int c = input->dims()[1];
+    const int h = input->dims()[2];
+    const int w = input->dims()[3];
+    const int f = filter.dims()[0];
+
+    // Zero every gradient buffer; each call is checked on its own so that a
+    // failure reports the returned error code.
+    auto zero_fill = [&dev_ctx](T* data, int64_t numel) {
+      int r = xpu::constant<T>(dev_ctx.x_context(), data, numel,
+                               static_cast<T>(0));
+      PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
+                        platform::errors::External(
+                            "XPU API return wrong value[%d], please check "
+                            "where Baidu Kunlun Card is properly installed.",
+                            r));
+    };
+    zero_fill(dx_data, input->numel());
+    zero_fill(dw_data, filter.numel());
+    zero_fill(doffset_data, offset.numel());
+    zero_fill(dmask_data, mask.numel());
+    zero_fill(filter_grad_tmp, filter.numel());
+
+    for (int i = 0; i < batch_size / im2col_step; ++i) {
+      const int first = i * im2col_step;  // first sample of this slice
+      int r = xpu::deformable_conv_grad<float, float, float, int>(
+          dev_ctx.x_context(), input_ptr + first * input_dim, filter_ptr,
+          offset_ptr + first * input_offset_dim,
+          mask_ptr + first * input_mask_dim,
+          output_grad_ptr + first * output_dim, dx_data + first * input_dim,
+          filter_grad_tmp, doffset_data + first * input_offset_dim,
+          dmask_data + first * input_mask_dim, n, c, h, w, f, ksize, strides,
+          paddings, dilations, groups, deformable_groups, nullptr, nullptr,
+          nullptr, nullptr, nullptr, true);
+      PADDLE_ENFORCE_EQ(
+          r, XPU_SUCCESS,
+          platform::errors::External(
+              "XPU deformable_conv_grad kernel return wrong value[%d].", r));
+      // Fold this slice's filter gradient into the running total in dw_data.
+      r = baidu::xpu::api::add<T>(dev_ctx.x_context(), filter_grad_tmp, dw_data,
+                                  dw_data, filter.numel());
+      PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
+                        platform::errors::External(
+                            "XPU add kernel return wrong value[%d].", r));
+    }
+
+    // Wait on the device context before releasing the scratch buffers.
+    dev_ctx.Wait();
+    xpu_free(filter_grad_tmp);
+    if (input_grad == nullptr) {
+      xpu_free(dx_data);
+    }
+    if (filter_grad == nullptr) {
+      xpu_free(dw_data);
+    }
+    if (offset_grad == nullptr) {
+      xpu_free(doffset_data);
+    }
+    if (mask_grad == nullptr) {
+      xpu_free(dmask_data);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+using XPUDeviceContext = paddle::platform::XPUDeviceContext;
+
+REGISTER_OP_XPU_KERNEL(deformable_conv,
+                       ops::DeformableConvXPUKernel<XPUDeviceContext, float>);
+REGISTER_OP_XPU_KERNEL(
+    deformable_conv_grad,
+    ops::DeformableConvGradXPUKernel<XPUDeviceContext, float>);
+
+#endif
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_deformable_conv_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_deformable_conv_op_xpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c611b629988884e07d9593de7f89a6121a119a6
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_deformable_conv_op_xpu.py
@@ -0,0 +1,274 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import sys
+sys.path.append("..")
+import unittest
+import numpy as np
+
+import paddle.fluid.core as core
+import paddle.fluid as fluid
+from op_test_xpu import OpTest, XPUOpTest
+import paddle
+from paddle.fluid import Program, program_guard
+
+
+def dmc_bilinear(data_im, height, width, h, w):
+    """Bilinearly interpolate ``data_im`` at the fractional point (h, w).
+
+    ``data_im`` is indexed as [row, column]. Any of the four neighbouring
+    pixels that fails the bound checks below contributes 0 to the result.
+    """
+    h_low = int(np.floor(h))
+    w_low = int(np.floor(w))
+    h_high = h_low + 1
+    w_high = w_low + 1
+
+    lh = h - h_low
+    lw = w - w_low
+    hh = 1 - lh
+    hw = 1 - lw
+
+    v1 = 0
+    if h_low >= 0 and w_low >= 0:
+        v1 = data_im[h_low, w_low]
+    v2 = 0
+    if h_low >= 0 and w_high <= width - 1:
+        v2 = data_im[h_low, w_high]
+    v3 = 0
+    if h_high <= height - 1 and w_low >= 0:
+        v3 = data_im[h_high, w_low]
+    v4 = 0
+    if h_high <= height - 1 and w_high <= width - 1:
+        v4 = data_im[h_high, w_high]
+
+    w1, w2, w3, w4 = hh * hw, hh * lw, lh * hw, lh * lw
+    val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
+
+    return val
+
+
+def dconv_im2col_gemm(input, offset, mask, filter, group, conv_param):
+    """NumPy reference for modulated deformable convolution on NCHW data.
+
+    Builds the im2col buffer by bilinearly sampling ``input`` at the offset
+    positions, scales each sample by ``mask``, then multiplies by ``filter``
+    group by group. ``conv_param`` holds 'stride', 'pad' and 'dilation' as
+    [h, w] pairs. The output must have the same spatial size as the input.
+    """
+    in_n, in_c, in_h, in_w = input.shape
+    out_c, f_c, f_h, f_w = filter.shape
+
+    assert offset.shape == (in_n, 2 * f_h * f_w, in_h, in_w)
+    assert mask.shape == (in_n, f_h * f_w, in_h, in_w)
+    assert f_c * group == in_c
+    assert np.mod(out_c, group) == 0
+
+    stride, pad, dilation = conv_param['stride'], conv_param['pad'],\
+        conv_param['dilation']
+    out_h = 1 + (in_h + 2 * pad[0] - (dilation[0] * (f_h - 1) + 1)) // stride[0]
+    out_w = 1 + (in_w + 2 * pad[1] - (dilation[1] * (f_w - 1) + 1)) // stride[1]
+    assert out_h == in_h
+    assert out_w == in_w
+
+    col_buffer = np.zeros((in_n, in_c * f_h * f_w, in_h * in_w))
+    for n in range(in_n):
+        for h in range(out_h):
+            for w in range(out_w):
+                # These tables depend only on (n, h, w); build them once here
+                # instead of once per channel and kernel tap.
+                offset_h_table = offset[n, ::2, h, w].reshape(f_h, f_w)
+                offset_w_table = offset[n, 1::2, h, w].reshape(f_h, f_w)
+                mask_table = mask[n, :, h, w].reshape(f_h, f_w)
+                for c in range(in_c):
+                    for kh in range(f_h):
+                        for kw in range(f_w):
+                            # Height uses index 0, width uses index 1 of
+                            # stride/dilation/pad.
+                            im_h = h * stride[0] + kh * dilation[0] \
+                                + offset_h_table[kh, kw] - pad[0]
+                            im_w = w * stride[1] + kw * dilation[1] \
+                                + offset_w_table[kh, kw] - pad[1]
+                            val = 0
+                            if im_h > -1 and im_w > -1 and \
+                                    im_h < in_h and im_w < in_w:
+                                val = dmc_bilinear(input[n, c], in_h, in_w,
+                                                   im_h, im_w)
+                            col_buffer[n, c * f_h * f_w + kh * f_w + kw,
+                                       h * in_w + w] = val * mask_table[kh, kw]
+
+    out = np.zeros((in_n, group, int(out_c // group), out_h * out_w))
+    weight = filter.reshape(group, int(out_c // group), f_c * f_h * f_w)
+    col_buffer = col_buffer.reshape(
+        (in_n, group, int(in_c // group * f_h * f_w), in_h * in_w))
+    for n in range(in_n):
+        for g in range(group):
+            out[n, g] = np.matmul(weight[g], col_buffer[n, g])
+    out = out.reshape(in_n, out_c, out_h, out_w)
+    return out
+
+
+class TestModulatedDeformableConvOp(XPUOpTest):
+    """Checks the XPU deformable_conv op against the NumPy reference."""
+    def setUp(self):
+        self.op_type = "deformable_conv"
+        self.dtype = np.float32
+        self.init_group()
+        self.init_dilation()
+        self.init_test_case()
+
+        conv_param = {
+            'stride': self.stride,
+            'pad': self.pad,
+            'dilation': self.dilations
+        }
+
+        input = np.random.random(self.input_size).astype(self.dtype)
+        offset = 10 * np.random.random(self.offset_size).astype(self.dtype)
+        mask = 10 * np.random.random(self.mask_size).astype(self.dtype)
+        filter = np.random.random(self.filter_size).astype(self.dtype)
+        output = dconv_im2col_gemm(input, offset, mask, filter, self.groups,
+                                   conv_param)
+        output = output.astype(self.dtype)
+
+        self.inputs = {
+            'Input': OpTest.np_dtype_to_fluid_dtype(input),
+            'Offset': OpTest.np_dtype_to_fluid_dtype(offset),
+            'Mask': OpTest.np_dtype_to_fluid_dtype(mask),
+            'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
+        }
+        self.attrs = {
+            'strides': self.stride,
+            'paddings': self.pad,
+            'groups': self.groups,
+            'deformable_groups': self.deformable_groups,
+            'im2col_step': self.im2col_step,
+            'dilations': self.dilations,
+        }
+        self.outputs = {'Output': output}
+
+    def has_cuda(self):
+        # use_cudnn/use_cuda are never set by this XPU test, so read them with
+        # a default instead of raising AttributeError.
+        return core.is_compiled_with_cuda() and (getattr(
+            self, 'use_cudnn', False) or getattr(self, 'use_cuda', False))
+
+    def test_check_output(self):
+        if core.is_compiled_with_xpu():
+            paddle.enable_static()
+            place = paddle.XPUPlace(0)
+            self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        if core.is_compiled_with_xpu():
+            paddle.enable_static()
+            place = paddle.XPUPlace(0)
+            self.check_grad_with_place(
+                place, {'Input', 'Offset', 'Mask', 'Filter'},
+                'Output',
+                max_relative_error=0.06)
+
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [1, 1]
+        self.dilations = [1, 1]
+        self.input_size = [2, 8, 4, 4]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [8, f_c, 3, 3]
+        self.im2col_step = 1
+        self.deformable_groups = 1
+        self._init_offset_and_mask_size()
+
+    def _init_offset_and_mask_size(self):
+        """Derive offset_size/mask_size from input_size and filter_size."""
+        kernel_area = self.filter_size[2] * self.filter_size[3]
+        offset_c = 2 * self.deformable_groups * kernel_area
+        mask_c = self.deformable_groups * kernel_area
+        self.offset_size = [
+            self.input_size[0], offset_c, self.input_size[2], self.input_size[3]
+        ]
+        self.mask_size = [
+            self.input_size[0], mask_c, self.input_size[2], self.input_size[3]
+        ]
+
+    def init_dilation(self):
+        self.dilations = [1, 1]
+
+    def init_group(self):
+        self.groups = 1
+
+
+class TestWithDilation(TestModulatedDeformableConvOp):
+    def init_test_case(self):
+        self.pad = [2, 2]
+        self.stride = [1, 1]
+        self.input_size = [4, 3, 4, 4]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [6, f_c, 3, 3]
+        self.im2col_step = 1
+        self.deformable_groups = 1
+        self._init_offset_and_mask_size()
+
+    def init_dilation(self):
+        self.dilations = [2, 2]
+
+
+class TestWith3x3(TestModulatedDeformableConvOp):
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [6, f_c, 3, 3]
+        self.im2col_step = 1
+        self.deformable_groups = 1
+        self._init_offset_and_mask_size()
+
+
+class TestModulatedDeformableConvInvalidInput(unittest.TestCase):
+    def test_error(self):
+        def test_invalid_input():
+            # The input is a plain Python list rather than a fluid.data var.
+            paddle.enable_static()
+            input = [1, 3, 32, 32]
+            offset = fluid.data(
+                name='offset', shape=[None, 3, 32, 32], dtype='float32')
+            mask = fluid.data(
+                name='mask', shape=[None, 3, 32, 32], dtype='float32')
+            fluid.layers.deformable_conv(
+                input, offset, mask, num_filters=4, filter_size=1)
+
+        self.assertRaises(TypeError, test_invalid_input)
+
+        def test_invalid_input_dtype():
+            # The input is declared int32; offset and mask are float32.
+            paddle.enable_static()
+            input = fluid.data(
+                name='input', shape=[None, 3, 32, 32], dtype='int32')
+            offset = fluid.data(
+                name='offset', shape=[None, 3, 32, 32], dtype='float32')
+            mask = fluid.data(
+                name='mask', shape=[None, 3, 32, 32], dtype='float32')
+            fluid.layers.deformable_conv(
+                input, offset, mask, num_filters=4, filter_size=1)
+
+        self.assertRaises(TypeError, test_invalid_input_dtype)
+
+
+if __name__ == '__main__':
+    unittest.main()