Unverified commit 07c67d5a, authored by 卖鱼的哲学, committed by GitHub

add deformable_conv op on xpu (#29234)

* rebase develop

* update deformable_conv op on xpu

* update deformable_conv op on xpu
Parent 1de32f82
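For reference, a minimal sketch of exercising the new kernel end to end through the existing `fluid.layers.deformable_conv` front end (assumes a Paddle build with XPU support and a visible Kunlun card; shapes mirror the unit test below):

```python
import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()
# NCHW input; with a 3x3 filter and deformable_groups=1, offset has
# 2*3*3 = 18 channels and mask has 3*3 = 9 channels.
data = fluid.data(name='data', shape=[2, 8, 4, 4], dtype='float32')
offset = fluid.data(name='offset', shape=[2, 18, 4, 4], dtype='float32')
mask = fluid.data(name='mask', shape=[2, 9, 4, 4], dtype='float32')
out = fluid.layers.deformable_conv(
    input=data, offset=offset, mask=mask, num_filters=8, filter_size=3,
    padding=1, groups=1, deformable_groups=1, im2col_step=1)

exe = fluid.Executor(paddle.XPUPlace(0))
exe.run(fluid.default_startup_program())
res, = exe.run(
    feed={
        'data': np.random.random([2, 8, 4, 4]).astype('float32'),
        'offset': np.random.random([2, 18, 4, 4]).astype('float32'),
        'mask': np.random.random([2, 9, 4, 4]).astype('float32'),
    },
    fetch_list=[out])
print(res.shape)  # (2, 8, 4, 4)
```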
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/xpu_header.h"
#include "xpu/refactor/math.h"
#include "xpu/refactor/nn.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
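// Forward kernel: zero-fills the output, then calls the fused XPU
// deformable_conv primitive on the batch in chunks of im2col_step samples.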
template <typename DeviceContext, typename T>
class DeformableConvXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("Input");
auto* offset = ctx.Input<Tensor>("Offset");
auto* mask = ctx.Input<Tensor>("Mask");
Tensor filter = *ctx.Input<Tensor>("Filter");
Tensor* output = ctx.Output<Tensor>("Output");
output->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
const int groups = ctx.Attr<int>("groups");
const int deformable_groups = ctx.Attr<int>("deformable_groups");
const int im2col_step = ctx.Attr<int>("im2col_step");
const std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
const std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
const std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
    PADDLE_ENFORCE_EQ(
        deformable_groups == 1, true,
        platform::errors::InvalidArgument(
            "XPU only supports deformable_groups == 1 in deformable_conv op."));
    PADDLE_ENFORCE_EQ(
        groups == 1, true,
        platform::errors::InvalidArgument(
            "XPU only supports groups == 1 in deformable_conv op."));
    PADDLE_ENFORCE_EQ(filter.dims()[2] <= 8 && filter.dims()[3] <= 8, true,
                      platform::errors::InvalidArgument(
                          "Filter height and width should be less than or "
                          "equal to 8 on XPU in deformable_conv op."));
const int batch_size = static_cast<int>(input->dims()[0]);
std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
const T* input_ptr = input->data<T>();
const T* filter_ptr = filter.data<T>();
    const float* offset_ptr = offset->data<float>();
    const float* mask_ptr = mask->data<float>();
    T* output_ptr = output->data<T>();
    // zero-initialize the output buffer
    const int zero = 0;
    int r = xpu::constant<T>(dev_ctx.x_context(), output_ptr, output->numel(),
                             zero);
    PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
                      platform::errors::External(
                          "XPU API returned a wrong value[%d]; please check "
                          "whether the Baidu Kunlun card is properly "
                          "installed.",
                          r));
int input_dim = input->numel() / input->dims()[0];
int input_offset_dim = offset->numel() / offset->dims()[0];
int input_mask_dim = mask->numel() / mask->dims()[0];
int output_dim =
output_shape_vec[1] * output_shape_vec[2] * output_shape_vec[3];
std::vector<int> ksize{static_cast<int>(filter.dims()[2]),
static_cast<int>(filter.dims()[3])};
int n = im2col_step;
int c = input->dims()[1];
int h = input->dims()[2];
int w = input->dims()[3];
int f = filter.dims()[0];
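    // Process the batch in chunks of im2col_step samples, advancing the
    // input/offset/mask/output pointers by one chunk per iteration.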
for (int i = 0; i < batch_size / im2col_step; ++i) {
int r = xpu::deformable_conv<float, float, float, int>(
dev_ctx.x_context(), input_ptr + i * im2col_step * input_dim,
filter_ptr, offset_ptr + i * im2col_step * input_offset_dim,
mask_ptr + i * im2col_step * input_mask_dim,
          output_ptr + i * im2col_step * output_dim, n, c, h, w, f, ksize,
strides, paddings, dilations, groups, deformable_groups, nullptr,
nullptr, nullptr, true);
      PADDLE_ENFORCE_EQ(
          r, XPU_SUCCESS,
          platform::errors::External(
              "XPU deformable_conv kernel returned a wrong value[%d].", r));
}
}
};
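// Backward kernel: computes gradients w.r.t. input, filter, offset and mask
// via the fused XPU deformable_conv_grad primitive, chunked like the forward
// pass.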
template <typename DeviceContext, typename T>
class DeformableConvGradXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* output_grad =
ctx.Input<Tensor>(framework::GradVarName("Output"));
Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
Tensor* filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));
Tensor* offset_grad = ctx.Output<Tensor>(framework::GradVarName("Offset"));
Tensor* mask_grad = ctx.Output<Tensor>(framework::GradVarName("Mask"));
T* dx_data = nullptr;
T* dw_data = nullptr;
T* dmask_data = nullptr;
T* doffset_data = nullptr;
if (input_grad != nullptr) {
input_grad->mutable_data<T>(ctx.GetPlace());
dx_data = input_grad->data<T>();
}
if (filter_grad != nullptr) {
filter_grad->mutable_data<T>(ctx.GetPlace());
dw_data = filter_grad->data<T>();
}
if (offset_grad != nullptr) {
offset_grad->mutable_data<T>(ctx.GetPlace());
doffset_data = offset_grad->data<T>();
}
if (mask_grad != nullptr) {
mask_grad->mutable_data<T>(ctx.GetPlace());
dmask_data = mask_grad->data<T>();
}
const Tensor* input = ctx.Input<Tensor>("Input");
Tensor offset = *ctx.Input<Tensor>("Offset");
Tensor mask = *ctx.Input<Tensor>("Mask");
Tensor filter = *ctx.Input<Tensor>("Filter");
int groups = ctx.Attr<int>("groups");
int deformable_groups = ctx.Attr<int>("deformable_groups");
int im2col_step = ctx.Attr<int>("im2col_step");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
    PADDLE_ENFORCE_EQ(
        deformable_groups == 1, true,
        platform::errors::InvalidArgument(
            "XPU only supports deformable_groups == 1 in deformable_conv op."));
    PADDLE_ENFORCE_EQ(
        groups == 1, true,
        platform::errors::InvalidArgument(
            "XPU only supports groups == 1 in deformable_conv op."));
    PADDLE_ENFORCE_EQ(filter.dims()[2] <= 8 && filter.dims()[3] <= 8, true,
                      platform::errors::InvalidArgument(
                          "Filter height and width should be less than or "
                          "equal to 8 on XPU in deformable_conv op."));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
const int batch_size = static_cast<int>(input->dims()[0]);
std::vector<int64_t> output_shape_vec(
framework::vectorize(output_grad->dims()));
const T* output_grad_ptr = output_grad->data<T>();
const T* input_ptr = input->data<T>();
const T* filter_ptr = filter.data<T>();
const float* offset_ptr = offset.data<float>();
const float* mask_ptr = mask.data<float>();
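    // The XPU grad primitive writes all four gradients unconditionally, so
    // allocate scratch buffers for any outputs the caller did not request
    // (freed at the end of Compute).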
    if (dx_data == nullptr) {
      PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast<void**>(&dx_data),
                                   input->numel() * sizeof(T)),
                        XPU_SUCCESS, platform::errors::ResourceExhausted(
                                         "XPU is out of memory"));
    }
    if (dw_data == nullptr) {
      PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast<void**>(&dw_data),
                                   filter.numel() * sizeof(T)),
                        XPU_SUCCESS, platform::errors::ResourceExhausted(
                                         "XPU is out of memory"));
    }
    if (doffset_data == nullptr) {
      PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast<void**>(&doffset_data),
                                   offset.numel() * sizeof(T)),
                        XPU_SUCCESS, platform::errors::ResourceExhausted(
                                         "XPU is out of memory"));
    }
    if (dmask_data == nullptr) {
      PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast<void**>(&dmask_data),
                                   mask.numel() * sizeof(T)),
                        XPU_SUCCESS, platform::errors::ResourceExhausted(
                                         "XPU is out of memory"));
    }
int input_dim = input->numel() / input->dims()[0];
int input_offset_dim = offset.numel() / offset.dims()[0];
int input_mask_dim = mask.numel() / mask.dims()[0];
int output_dim =
output_shape_vec[1] * output_shape_vec[2] * output_shape_vec[3];
std::vector<int> ksize{static_cast<int>(filter.dims()[2]),
static_cast<int>(filter.dims()[3])};
int n = im2col_step;
int c = input->dims()[1];
int h = input->dims()[2];
int w = input->dims()[3];
int f = filter.dims()[0];
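    // The filter gradient must be accumulated across chunks: each call writes
    // into filter_grad_tmp, which is then added into dw_data below.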
T* filter_grad_tmp = nullptr;
    PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast<void**>(&filter_grad_tmp),
                                 filter.numel() * sizeof(T)),
                      XPU_SUCCESS, platform::errors::ResourceExhausted(
                                       "XPU is out of memory"));
    // zero-initialize all gradient buffers before accumulation
const int zero = 0;
int r_dx =
xpu::constant<T>(dev_ctx.x_context(), dx_data, input->numel(), zero);
int r_dw =
xpu::constant<T>(dev_ctx.x_context(), dw_data, filter.numel(), zero);
int r_doffset = xpu::constant<T>(dev_ctx.x_context(), doffset_data,
offset.numel(), zero);
int r_dmask =
xpu::constant<T>(dev_ctx.x_context(), dmask_data, mask.numel(), zero);
int r_filter = xpu::constant<T>(dev_ctx.x_context(), filter_grad_tmp,
filter.numel(), zero);
auto ret = (r_dx == xpu::Error_t::SUCCESS) && (r_dx == r_dw) &&
(r_dx == r_doffset) && (r_dx == r_dmask) && (r_dx == r_filter);
    PADDLE_ENFORCE_EQ(ret, true,
                      platform::errors::External(
                          "XPU API returned a wrong value; please check "
                          "whether the Baidu Kunlun card is properly "
                          "installed."));
for (int i = 0; i < batch_size / im2col_step; ++i) {
int r = xpu::deformable_conv_grad<float, float, float, int>(
dev_ctx.x_context(), input_ptr + i * im2col_step * input_dim,
filter_ptr, offset_ptr + i * im2col_step * input_offset_dim,
mask_ptr + i * im2col_step * input_mask_dim,
output_grad_ptr + i * im2col_step * output_dim,
dx_data + i * im2col_step * input_dim, filter_grad_tmp,
doffset_data + i * im2col_step * input_offset_dim,
dmask_data + i * im2col_step * input_mask_dim, n, c, h, w, f, ksize,
strides, paddings, dilations, groups, deformable_groups, nullptr,
nullptr, nullptr, nullptr, nullptr, true);
      PADDLE_ENFORCE_EQ(
          r, XPU_SUCCESS,
          platform::errors::External(
              "XPU deformable_conv_grad kernel returned a wrong value[%d].",
              r));
      r = baidu::xpu::api::add<T>(dev_ctx.x_context(), filter_grad_tmp,
                                  dw_data, dw_data, filter.numel());
      PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
                        platform::errors::External(
                            "XPU add kernel returned a wrong value[%d].", r));
}
dev_ctx.Wait();
xpu_free(filter_grad_tmp);
if (input_grad == nullptr) {
xpu_free(dx_data);
}
if (filter_grad == nullptr) {
xpu_free(dw_data);
}
if (offset_grad == nullptr) {
xpu_free(doffset_data);
}
if (mask_grad == nullptr) {
xpu_free(dmask_data);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using XPUDeviceContext = paddle::platform::XPUDeviceContext;
REGISTER_OP_XPU_KERNEL(deformable_conv,
ops::DeformableConvXPUKernel<XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(
deformable_conv_grad,
ops::DeformableConvGradXPUKernel<XPUDeviceContext, float>);
#endif
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
sys.path.append("..")
import unittest
import numpy as np
import paddle.fluid.core as core
import paddle.fluid as fluid
from op_test_xpu import OpTest, XPUOpTest
import paddle
from paddle.fluid import Program, program_guard
def dmc_bilinear(data_im, height, width, h, w):
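    """Bilinearly interpolate data_im at the fractional location (h, w);
    neighbors falling outside the height x width grid contribute zero."""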
h_low = int(np.floor(h))
w_low = int(np.floor(w))
h_high = h_low + 1
w_high = w_low + 1
lh = h - h_low
lw = w - w_low
hh = 1 - lh
hw = 1 - lw
v1 = 0
if h_low >= 0 and w_low >= 0:
v1 = data_im[h_low, w_low]
v2 = 0
if h_low >= 0 and w_high <= width - 1:
v2 = data_im[h_low, w_high]
v3 = 0
if h_high <= height - 1 and w_low >= 0:
v3 = data_im[h_high, w_low]
v4 = 0
if h_high <= height - 1 and w_high <= width - 1:
v4 = data_im[h_high, w_high]
w1, w2, w3, w4 = hh * hw, hh * lw, lh * hw, lh * lw
val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
return val
def dconv_im2col_gemm(input, offset, mask, filter, group, conv_param):
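    """Reference implementation of modulated deformable conv: build the
    deformed, mask-weighted im2col buffer, then apply a grouped GEMM."""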
in_n, in_c, in_h, in_w = input.shape
out_c, f_c, f_h, f_w = filter.shape
assert offset.shape == (in_n, 2 * f_h * f_w, in_h, in_w)
assert mask.shape == (in_n, f_h * f_w, in_h, in_w)
assert f_c * group == in_c
assert np.mod(out_c, group) == 0
stride, pad, dilation = conv_param['stride'], conv_param['pad'],\
conv_param['dilation']
out_h = 1 + (in_h + 2 * pad[0] - (dilation[0] * (f_h - 1) + 1)) // stride[0]
out_w = 1 + (in_w + 2 * pad[1] - (dilation[1] * (f_w - 1) + 1)) // stride[1]
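    # e.g. in_h = 4, pad = [1, 1], dilation = [1, 1], f_h = 3, stride = [1, 1]
    # gives out_h = 1 + (4 + 2 - 3) // 1 = 4, matching the asserts below.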
assert out_h == in_h
assert out_w == in_w
col_buffer = np.zeros((in_n, in_c * f_h * f_w, in_h * in_w))
for n in range(in_n):
for c in range(in_c):
for h in range(out_h):
for w in range(out_w):
for kh in range(f_h):
for kw in range(f_w):
offset_h_table = \
offset[n, ::2, h, w].reshape(f_h, f_w)
offset_w_table = \
offset[n, 1::2, h, w].reshape(f_h, f_w)
mask_table = \
mask[n, :, h, w].reshape(f_h, f_w)
offset_h = offset_h_table[kh, kw]
offset_w = offset_w_table[kh, kw]
val = 0
im_h = h * stride[0] + kh * dilation[0] \
+ offset_h - pad[0]
                            im_w = w * stride[1] + kw * dilation[1] \
                                + offset_w - pad[1]
                            if im_h > -1 and im_w > -1 and \
                                    im_h < in_h and im_w < in_w:
val = dmc_bilinear(input[n, c], in_h, in_w,
im_h, im_w)
val_out = val * mask_table[kh, kw]
col_buffer[n, c * f_h * f_w + kh * f_w + kw, h *
in_w + w] = val_out
out = np.zeros((in_n, group, int(out_c // group), out_h * out_w))
weight = filter.reshape(group, int(out_c // group), f_c * f_h * f_w)
col_buffer = col_buffer.reshape(
(in_n, group, int(in_c // group * f_h * f_w), in_h * in_w))
for n in range(in_n):
for g in range(group):
out[n, g] = np.matmul(weight[g], col_buffer[n, g])
out = out.reshape(in_n, out_c, out_h, out_w)
return out
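# Sanity check (informal): with all-zero offsets and an all-ones mask, every
# sample lands on the regular grid, so dconv_im2col_gemm reduces to an
# ordinary convolution with the same stride/pad/dilation.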
class TestModulatedDeformableConvOp(XPUOpTest):
def setUp(self):
self.op_type = "deformable_conv"
self.dtype = np.float32
self.init_group()
self.init_dilation()
self.init_test_case()
conv_param = {
'stride': self.stride,
'pad': self.pad,
'dilation': self.dilations
}
input = np.random.random(self.input_size).astype(self.dtype)
offset = 10 * np.random.random(self.offset_size).astype(self.dtype)
mask = 10 * np.random.random(self.mask_size).astype(self.dtype)
filter = np.random.random(self.filter_size).astype(self.dtype)
output = dconv_im2col_gemm(input, offset, mask, filter, self.groups,
conv_param)
output = output.astype(self.dtype)
self.inputs = {
'Input': OpTest.np_dtype_to_fluid_dtype(input),
'Offset': OpTest.np_dtype_to_fluid_dtype(offset),
'Mask': OpTest.np_dtype_to_fluid_dtype(mask),
'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
}
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
'groups': self.groups,
'deformable_groups': self.deformable_groups,
'im2col_step': self.im2col_step,
'dilations': self.dilations,
}
self.outputs = {'Output': output}
    def has_cuda(self):
        # use_cudnn/use_cuda are not set by the XPU cases, so guard the lookup
        return core.is_compiled_with_cuda() and (getattr(
            self, 'use_cudnn', False) or getattr(self, 'use_cuda', False))
def test_check_output(self):
if core.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
if core.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(
place, {'Input', 'Offset', 'Mask', 'Filter'},
'Output',
max_relative_error=0.06)
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.dilations = [1, 1]
self.input_size = [2, 8, 4, 4] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [8, f_c, 3, 3]
self.im2col_step = 1
self.deformable_groups = 1
offset_c = 2 * self.deformable_groups * self.filter_size[
2] * self.filter_size[3]
mask_c = self.deformable_groups * self.filter_size[
2] * self.filter_size[3]
self.offset_size = [
self.input_size[0], offset_c, self.input_size[2], self.input_size[3]
]
self.mask_size = [
self.input_size[0], mask_c, self.input_size[2], self.input_size[3]
]
def init_dilation(self):
self.dilations = [1, 1]
def init_group(self):
self.groups = 1
class TestWithDilation(TestModulatedDeformableConvOp):
def init_test_case(self):
self.pad = [2, 2]
self.stride = [1, 1]
self.input_size = [4, 3, 4, 4] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
self.im2col_step = 1
self.deformable_groups = 1
offset_c = 2 * self.deformable_groups * self.filter_size[
2] * self.filter_size[3]
mask_c = self.deformable_groups * self.filter_size[
2] * self.filter_size[3]
self.offset_size = [
self.input_size[0], offset_c, self.input_size[2], self.input_size[3]
]
self.mask_size = [
self.input_size[0], mask_c, self.input_size[2], self.input_size[3]
]
def init_dilation(self):
self.dilations = [2, 2]
class TestWith3x3(TestModulatedDeformableConvOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
self.im2col_step = 1
self.deformable_groups = 1
offset_c = 2 * self.deformable_groups * self.filter_size[
2] * self.filter_size[3]
mask_c = self.deformable_groups * self.filter_size[
2] * self.filter_size[3]
self.offset_size = [
self.input_size[0], offset_c, self.input_size[2], self.input_size[3]
]
self.mask_size = [
self.input_size[0], mask_c, self.input_size[2], self.input_size[3]
]
class TestModulatedDeformableConvInvalidInput(unittest.TestCase):
def test_error(self):
def test_invalid_input():
paddle.enable_static()
input = [1, 3, 32, 32]
offset = fluid.data(
name='offset', shape=[None, 3, 32, 32], dtype='float32')
mask = fluid.data(
name='mask', shape=[None, 3, 32, 32], dtype='float32')
loss = fluid.layers.deformable_conv(
input, offset, mask, num_filters=4, filter_size=1)
self.assertRaises(TypeError, test_invalid_input)
def test_invalid_offset():
paddle.enable_static()
input = fluid.data(
name='input', shape=[None, 3, 32, 32], dtype='int32')
offset = fluid.data(
name='offset', shape=[None, 3, 32, 32], dtype='float32')
mask = fluid.data(
name='mask', shape=[None, 3, 32, 32], dtype='float32')
loss = fluid.layers.deformable_conv(
input, offset, mask, num_filters=4, filter_size=1)
self.assertRaises(TypeError, test_invalid_offset)
if __name__ == '__main__':
unittest.main()