From 55ce36e98146d68eab61223d7119208a1517b96f Mon Sep 17 00:00:00 2001 From: whs Date: Thu, 25 Apr 2019 23:09:03 +0800 Subject: [PATCH] Speedup roi_perspective_transform op by caching the information of linear interpolation in forward (#17090) * Cache the information of linear interpolation in forward and use it in backward. test=develop * Fix cuda kernel. test=develop --- .../detection/roi_perspective_transform_op.cc | 18 +++ .../detection/roi_perspective_transform_op.cu | 135 +++++++----------- python/paddle/fluid/layers/detection.py | 8 +- .../test_roi_perspective_transform_op.py | 4 + 4 files changed, 78 insertions(+), 87 deletions(-) diff --git a/paddle/fluid/operators/detection/roi_perspective_transform_op.cc b/paddle/fluid/operators/detection/roi_perspective_transform_op.cc index 5b84221cfa..54dd28c986 100644 --- a/paddle/fluid/operators/detection/roi_perspective_transform_op.cc +++ b/paddle/fluid/operators/detection/roi_perspective_transform_op.cc @@ -494,6 +494,8 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel { auto out_dims = framework::make_ddim(out_dims_v); ctx->SetOutputDim("Out", out_dims); + ctx->SetOutputDim("Out2InIdx", out_dims); + ctx->SetOutputDim("Out2InWeights", out_dims); ctx->ShareLoD("ROIs", /*->*/ "Out"); } @@ -550,6 +552,20 @@ class ROIPerspectiveTransformOpMaker "(Tensor), " "The output of ROIPerspectiveTransformOp is a 4-D tensor with shape " "(num_rois, channels, transformed_h, transformed_w)."); + AddOutput("Out2InIdx", + "(Tensor), " + "An intermediate tensor used to map indexes of input feature map " + "and indexes of output feature map." + "The shape of the tensor is [out_size, 4] and out_size is the " + "number of elements in output feature map.") + .AsIntermediate(); + AddOutput("Out2InWeights", + "(Tensor), " + "An intermediate tensor used to record the weights of bilinear " + "interpolatein for each element in output. The shape of the " + "tensor is [out_size, 4] and out_size is the number of elements " + "in output feature map.") + .AsIntermediate(); AddAttr("spatial_scale", "(float, default 1.0), " "Spatial scale factor to scale ROI coords.") @@ -580,6 +596,8 @@ class ROIPerspectiveTransformGradDescMaker op->SetType("roi_perspective_transform_grad"); op->SetInput("X", Input("X")); op->SetInput("ROIs", Input("ROIs")); + op->SetInput("Out2InIdx", Output("Out2InIdx")); + op->SetInput("Out2InWeights", Output("Out2InWeights")); op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); op->SetOutput(framework::GradVarName("X"), InputGrad("X")); op->SetAttrMap(Attrs()); diff --git a/paddle/fluid/operators/detection/roi_perspective_transform_op.cu b/paddle/fluid/operators/detection/roi_perspective_transform_op.cu index 862d664d42..74c8384e1e 100644 --- a/paddle/fluid/operators/detection/roi_perspective_transform_op.cu +++ b/paddle/fluid/operators/detection/roi_perspective_transform_op.cu @@ -14,6 +14,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/float16.h" @@ -115,8 +116,9 @@ __device__ bool in_quad(T x, T y, T roi_x[], T roi_y[]) { template __device__ void bilinear_interpolate(const T* in_data, const int channels, const int width, const int height, - int in_n, int in_c, T in_w, T in_h, - T* val) { + int in_n, int in_c, T in_w, T in_h, T* val, + int out_idx, int* out2in_idx, + T* out2in_w) { // Deal with cases that source coords are out of feature map boundary if (GT(-0.5, in_w) || GT(in_w, width - 0.5) || GT(-0.5, in_h) || GT(in_h, height - 0.5)) { @@ -165,6 +167,16 @@ __device__ void bilinear_interpolate(const T* in_data, const int channels, T w3 = w_floor * h_floor; T w4 = w_floor * h_ceil; val[0] = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; + + int base_idx = (in_n * channels + in_c) * height * width; + out2in_idx[out_idx * 4] = base_idx + in_h_floor * width + in_w_floor; + out2in_idx[out_idx * 4 + 1] = base_idx + in_h_ceil * width + in_w_floor; + out2in_idx[out_idx * 4 + 2] = base_idx + in_h_ceil * width + in_w_ceil; + out2in_idx[out_idx * 4 + 3] = base_idx + in_h_floor * width + in_w_ceil; + out2in_w[out_idx * 4] = w1; + out2in_w[out_idx * 4 + 1] = w2; + out2in_w[out_idx * 4 + 2] = w3; + out2in_w[out_idx * 4 + 3] = w4; } /** @@ -262,13 +274,11 @@ __device__ void get_transform_matrix(const int transformed_width, } template -__global__ void RoiTransformKernel(const float* input_data, - const float* rois_data, - const int* roi2image_data, int num_rois, - int in_height, int in_width, int channels, - int transformed_height, - int transformed_width, float spatial_scale, - T* output_data) { +__global__ void RoiTransformKernel( + const float* input_data, const float* rois_data, const int* roi2image_data, + int num_rois, int in_height, int in_width, int channels, + int transformed_height, int transformed_width, float spatial_scale, + T* output_data, int* out2in_idx, T* out2in_w) { int output_size = num_rois * transformed_height * transformed_width * channels; @@ -311,7 +321,8 @@ __global__ void RoiTransformKernel(const float* input_data, // Perform bilinear interpolation int in_n = roi2image_data[n]; bilinear_interpolate(input_data, channels, in_width, in_height, in_n, - c, in_w, in_h, output_data + index); + c, in_w, in_h, output_data + index, index, + out2in_idx, out2in_w); } } else { @@ -328,6 +339,16 @@ class CUDAROIPerspectiveTransformOpKernel : public framework::OpKernel { auto* in = ctx.Input("X"); auto* rois = ctx.Input("ROIs"); auto* out = ctx.Output("Out"); + auto* out2in_idx = ctx.Output("Out2InIdx"); + auto* out2in_w = ctx.Output("Out2InWeights"); + + int* out2in_idx_data = + out2in_idx->mutable_data({out->numel(), 4}, ctx.GetPlace()); + T* out2in_w_data = + out2in_w->mutable_data({out->numel(), 4}, ctx.GetPlace()); + + math::SetConstant init; + init(ctx.cuda_device_context(), out2in_idx, static_cast(-1)); auto transformed_height = ctx.Attr("transformed_height"); auto transformed_width = ctx.Attr("transformed_width"); @@ -364,7 +385,7 @@ class CUDAROIPerspectiveTransformOpKernel : public framework::OpKernel { RoiTransformKernel<<>>( input_data, rois_data, roi2image_dev.data(), rois_num, in_height, in_width, channels, transformed_height, transformed_width, - spatial_scale, output_data); + spatial_scale, output_data, out2in_idx_data, out2in_w_data); } }; @@ -420,60 +441,17 @@ __device__ T get_feature_gradient(T xs, T ys, int w, int h, const int width, } template -__global__ void RoiTransformGradKernel( - const size_t* lod, const T* rois_data, int batch_size, int num_rois, - int in_height, int in_width, int channels, int transformed_height, - int transformed_width, float spatial_scale, const T* out_grad_data, - T* in_grad_data) { - int input_size = batch_size * in_height * in_width * channels; - - CUDA_1D_KERNEL_LOOP(index, input_size) { - // (n, c, h, w) coords in input - int in_w = idx4_4(index, batch_size, channels, in_height, in_width); - int in_h = idx4_3(index, batch_size, channels, in_height, in_width); - int c = idx4_2(index, batch_size, channels, in_height, in_width); - int n = idx4_1(index, batch_size, channels, in_height, in_width); - - T gradient = 0.0; - // Accumulate gradient over all RoIs that interpolated this element - for (size_t roi_idx = lod[n]; roi_idx < lod[n + 1]; ++roi_idx) { - const T* rois = rois_data + roi_idx * 8; - T roi_x[4]; - T roi_y[4]; - for (int k = 0; k < 4; ++k) { - roi_x[k] = rois[2 * k] * spatial_scale; - roi_y[k] = rois[2 * k + 1] * spatial_scale; - } - - // Get transform matrix - T matrix[9]; - get_transform_matrix(transformed_width, transformed_height, roi_x, - roi_y, matrix); - - const T* out_grad_ptr = - out_grad_data + - (roi_idx * channels + c) * transformed_height * transformed_width; - for (int out_h = 0; out_h < transformed_height; ++out_h) { - for (int out_w = 0; out_w < transformed_width; ++out_w) { - T src_w; - T src_h; - get_source_coords(matrix, out_w, out_h, &src_w, &src_h); - if (in_quad(src_w, src_h, roi_x, roi_y)) { - if (GT(-0.5, src_w) || - GT(src_w, static_cast(in_width - 0.5)) || - GT(-0.5, src_h) || - GT(src_h, static_cast(in_height - 0.5))) { - continue; - } - T weight = get_feature_gradient(src_w, src_h, in_w, in_h, - in_width, in_height); - gradient += - out_grad_ptr[out_h * transformed_width + out_w] * weight; - } - } - } +__global__ void RoiTransformGradKernel(int out_size, const int* out2in_idx_data, + const T* out2in_w_data, + const T* out_grad_data, + T* in_grad_data) { + CUDA_1D_KERNEL_LOOP(index, out_size * 4) { + int in_idx = out2in_idx_data[index]; + if (in_idx >= 0) { + int out_idx = index / 4; + atomicAdd(in_grad_data + in_idx, + out_grad_data[out_idx] * out2in_w_data[index]); } - in_grad_data[index] = gradient; } } @@ -481,39 +459,24 @@ template class CUDAROIPerspectiveTransformGradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* in = ctx.Input("X"); - auto* rois = ctx.Input("ROIs"); + auto* out2in_idx = ctx.Input("Out2InIdx"); + auto* out2in_w = ctx.Input("Out2InWeights"); auto* out_grad = ctx.Input(framework::GradVarName("Out")); auto* in_grad = ctx.Output(framework::GradVarName("X")); - auto transformed_height = ctx.Attr("transformed_height"); - auto transformed_width = ctx.Attr("transformed_width"); - auto spatial_scale = ctx.Attr("spatial_scale"); - - auto in_dims = in->dims(); - int batch_size = in_dims[0]; - int channels = in_dims[1]; - int in_height = in_dims[2]; - int in_width = in_dims[3]; - int rois_num = rois->dims()[0]; - T* in_grad_data = in_grad->mutable_data(ctx.GetPlace()); const T* out_grad_data = out_grad->data(); - const T* rois_data = rois->data(); - - auto lod = rois->lod().back(); - auto lod_data = lod.CUDAData(ctx.GetPlace()); + const int* out2in_idx_data = out2in_idx->data(); + const T* out2in_w_data = out2in_w->data(); - int in_size = in->numel(); + int out_size = out_grad->numel(); auto stream = ctx.cuda_device_context().stream(); int block = 512; - int grid = (in_size + block - 1) / block; + int grid = (out_size * 4 + block - 1) / block; RoiTransformGradKernel<<>>( - lod_data, rois_data, batch_size, rois_num, in_height, in_width, - channels, transformed_height, transformed_width, spatial_scale, - out_grad_data, in_grad_data); + out_size, out2in_idx_data, out2in_w_data, out_grad_data, in_grad_data); } }; diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 09a573e838..3da5bedd39 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -1827,11 +1827,17 @@ def roi_perspective_transform(input, helper = LayerHelper('roi_perspective_transform', **locals()) dtype = helper.input_dtype() out = helper.create_variable_for_type_inference(dtype) + out2in_idx = helper.create_variable_for_type_inference(dtype="int32") + out2in_w = helper.create_variable_for_type_inference(dtype) helper.append_op( type="roi_perspective_transform", inputs={"X": input, "ROIs": rois}, - outputs={"Out": out}, + outputs={ + "Out": out, + "Out2InIdx": out2in_idx, + "Out2InWeights": out2in_w + }, attrs={ "transformed_height": transformed_height, "transformed_width": transformed_width, diff --git a/python/paddle/fluid/tests/unittests/test_roi_perspective_transform_op.py b/python/paddle/fluid/tests/unittests/test_roi_perspective_transform_op.py index de67513156..90c5e210a2 100644 --- a/python/paddle/fluid/tests/unittests/test_roi_perspective_transform_op.py +++ b/python/paddle/fluid/tests/unittests/test_roi_perspective_transform_op.py @@ -299,6 +299,10 @@ class TestROIPoolOp(OpTest): self.check_output() def test_check_grad(self): + self.outputs['Out2InIdx'] = np.zeros( + [np.product(self.outputs['Out'].shape), 4]).astype("int32") + self.outputs['Out2InWeights'] = np.zeros( + [np.product(self.outputs['Out'].shape), 4]).astype("float32") self.check_grad(['X'], 'Out') -- GitLab