未验证 提交 eee6b3a7 编写于 作者: R Rayman 提交者: GitHub

speed_up for deformable conv (#46997)

上级 a6a2618e
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/deformable_conv_functor.h" #include "paddle/phi/kernels/funcs/deformable_conv_functor.h"
#include "paddle/phi/kernels/transpose_kernel.h"
#include "paddle/utils/optional.h" #include "paddle/utils/optional.h"
namespace phi { namespace phi {
...@@ -38,6 +39,11 @@ void DeformableConvKernel(const Context& dev_ctx, ...@@ -38,6 +39,11 @@ void DeformableConvKernel(const Context& dev_ctx,
DenseTensor* out) { DenseTensor* out) {
const int batch_size = static_cast<int>(x.dims()[0]); const int batch_size = static_cast<int>(x.dims()[0]);
int temp_step = std::min(64, batch_size);
if (batch_size % temp_step == 0) {
im2col_step = temp_step;
}
std::vector<int64_t> filter_shape_vec(phi::vectorize(filter.dims())); std::vector<int64_t> filter_shape_vec(phi::vectorize(filter.dims()));
std::vector<int64_t> output_shape_vec(phi::vectorize(out->dims())); std::vector<int64_t> output_shape_vec(phi::vectorize(out->dims()));
...@@ -101,8 +107,11 @@ void DeformableConvKernel(const Context& dev_ctx, ...@@ -101,8 +107,11 @@ void DeformableConvKernel(const Context& dev_ctx,
dilations, dilations,
deformable_groups, deformable_groups,
col_buffer_ptr); col_buffer_ptr);
DenseTensor output_3d = output_4d.Slice(i, i + 1).Resize( DenseTensor output_3d = output_4d.Slice(i, i + 1).Resize(phi::slice_ddim(
phi::slice_ddim(output_4d.dims(), 1, output_4d.dims().size())); output_4d.dims(),
1,
output_4d.dims().size())); // group * C/group * (im2col_step * H * W)
// get the product of pixel and weight // get the product of pixel and weight
for (int g = 0; g < groups; ++g) { for (int g = 0; g < groups; ++g) {
DenseTensor weight_3d_slice = weight_3d.Slice(g, g + 1).Resize( DenseTensor weight_3d_slice = weight_3d.Slice(g, g + 1).Resize(
...@@ -110,8 +119,11 @@ void DeformableConvKernel(const Context& dev_ctx, ...@@ -110,8 +119,11 @@ void DeformableConvKernel(const Context& dev_ctx,
DenseTensor col_buffer_3d_slice = DenseTensor col_buffer_3d_slice =
col_buffer_3d.Slice(g, g + 1).Resize(phi::slice_ddim( col_buffer_3d.Slice(g, g + 1).Resize(phi::slice_ddim(
col_buffer_3d.dims(), 1, col_buffer_3d.dims().size())); col_buffer_3d.dims(), 1, col_buffer_3d.dims().size()));
DenseTensor output_3d_slice = output_3d.Slice(g, g + 1).Resize( DenseTensor output_3d_slice =
phi::slice_ddim(output_3d.dims(), 1, output_3d.dims().size())); output_3d.Slice(g, g + 1).Resize(phi::slice_ddim(
output_3d.dims(),
1,
output_3d.dims().size())); // (C/group) * (im2col_step * H * W)
blas.MatMul(weight_3d_slice, blas.MatMul(weight_3d_slice,
false, false,
col_buffer_3d_slice, col_buffer_3d_slice,
...@@ -121,7 +133,29 @@ void DeformableConvKernel(const Context& dev_ctx, ...@@ -121,7 +133,29 @@ void DeformableConvKernel(const Context& dev_ctx,
T(0.0)); T(0.0));
} }
} }
// swap axis to get the right result when im2col_step is greater than 1
if (im2col_step > 1) {
std::vector<int> axis(4);
axis[0] = 0;
axis[1] = 2;
axis[2] = 1;
axis[3] = 3;
DenseTensor real_output_buffer = phi::Transpose<T, Context>(
dev_ctx,
output_4d.Resize(
phi::make_ddim({batch_size / im2col_step,
output_shape_vec[1],
im2col_step,
output_shape_vec[2] * output_shape_vec[3]})),
axis);
out->ShareDataWith(real_output_buffer)
.Resize(phi::make_ddim(output_shape_vec));
} else {
out->ShareDataWith(output_buffer).Resize(phi::make_ddim(output_shape_vec)); out->ShareDataWith(output_buffer).Resize(phi::make_ddim(output_shape_vec));
}
} }
} // namespace phi } // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册