未验证 提交 55ce36e9 编写于 作者: W whs 提交者: GitHub

Speedup roi_perspective_transform op by caching the information of linear...

Speedup roi_perspective_transform op by caching the information of linear interpolation in forward (#17090)

* Cache the information of linear interpolation in forward and use it in backward.
test=develop

* Fix cuda kernel.
test=develop
上级 842ded14
...@@ -494,6 +494,8 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel { ...@@ -494,6 +494,8 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel {
auto out_dims = framework::make_ddim(out_dims_v); auto out_dims = framework::make_ddim(out_dims_v);
ctx->SetOutputDim("Out", out_dims); ctx->SetOutputDim("Out", out_dims);
ctx->SetOutputDim("Out2InIdx", out_dims);
ctx->SetOutputDim("Out2InWeights", out_dims);
ctx->ShareLoD("ROIs", /*->*/ "Out"); ctx->ShareLoD("ROIs", /*->*/ "Out");
} }
...@@ -550,6 +552,20 @@ class ROIPerspectiveTransformOpMaker ...@@ -550,6 +552,20 @@ class ROIPerspectiveTransformOpMaker
"(Tensor), " "(Tensor), "
"The output of ROIPerspectiveTransformOp is a 4-D tensor with shape " "The output of ROIPerspectiveTransformOp is a 4-D tensor with shape "
"(num_rois, channels, transformed_h, transformed_w)."); "(num_rois, channels, transformed_h, transformed_w).");
AddOutput("Out2InIdx",
"(Tensor), "
"An intermediate tensor used to map indexes of input feature map "
"and indexes of output feature map."
"The shape of the tensor is [out_size, 4] and out_size is the "
"number of elements in output feature map.")
.AsIntermediate();
AddOutput("Out2InWeights",
"(Tensor), "
"An intermediate tensor used to record the weights of bilinear "
"interpolatein for each element in output. The shape of the "
"tensor is [out_size, 4] and out_size is the number of elements "
"in output feature map.")
.AsIntermediate();
AddAttr<float>("spatial_scale", AddAttr<float>("spatial_scale",
"(float, default 1.0), " "(float, default 1.0), "
"Spatial scale factor to scale ROI coords.") "Spatial scale factor to scale ROI coords.")
...@@ -580,6 +596,8 @@ class ROIPerspectiveTransformGradDescMaker ...@@ -580,6 +596,8 @@ class ROIPerspectiveTransformGradDescMaker
op->SetType("roi_perspective_transform_grad"); op->SetType("roi_perspective_transform_grad");
op->SetInput("X", Input("X")); op->SetInput("X", Input("X"));
op->SetInput("ROIs", Input("ROIs")); op->SetInput("ROIs", Input("ROIs"));
op->SetInput("Out2InIdx", Output("Out2InIdx"));
op->SetInput("Out2InWeights", Output("Out2InWeights"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X")); op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs()); op->SetAttrMap(Attrs());
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
...@@ -115,8 +116,9 @@ __device__ bool in_quad(T x, T y, T roi_x[], T roi_y[]) { ...@@ -115,8 +116,9 @@ __device__ bool in_quad(T x, T y, T roi_x[], T roi_y[]) {
template <typename T> template <typename T>
__device__ void bilinear_interpolate(const T* in_data, const int channels, __device__ void bilinear_interpolate(const T* in_data, const int channels,
const int width, const int height, const int width, const int height,
int in_n, int in_c, T in_w, T in_h, int in_n, int in_c, T in_w, T in_h, T* val,
T* val) { int out_idx, int* out2in_idx,
T* out2in_w) {
// Deal with cases that source coords are out of feature map boundary // Deal with cases that source coords are out of feature map boundary
if (GT<T>(-0.5, in_w) || GT<T>(in_w, width - 0.5) || GT<T>(-0.5, in_h) || if (GT<T>(-0.5, in_w) || GT<T>(in_w, width - 0.5) || GT<T>(-0.5, in_h) ||
GT<T>(in_h, height - 0.5)) { GT<T>(in_h, height - 0.5)) {
...@@ -165,6 +167,16 @@ __device__ void bilinear_interpolate(const T* in_data, const int channels, ...@@ -165,6 +167,16 @@ __device__ void bilinear_interpolate(const T* in_data, const int channels,
T w3 = w_floor * h_floor; T w3 = w_floor * h_floor;
T w4 = w_floor * h_ceil; T w4 = w_floor * h_ceil;
val[0] = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; val[0] = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
int base_idx = (in_n * channels + in_c) * height * width;
out2in_idx[out_idx * 4] = base_idx + in_h_floor * width + in_w_floor;
out2in_idx[out_idx * 4 + 1] = base_idx + in_h_ceil * width + in_w_floor;
out2in_idx[out_idx * 4 + 2] = base_idx + in_h_ceil * width + in_w_ceil;
out2in_idx[out_idx * 4 + 3] = base_idx + in_h_floor * width + in_w_ceil;
out2in_w[out_idx * 4] = w1;
out2in_w[out_idx * 4 + 1] = w2;
out2in_w[out_idx * 4 + 2] = w3;
out2in_w[out_idx * 4 + 3] = w4;
} }
/** /**
...@@ -262,13 +274,11 @@ __device__ void get_transform_matrix(const int transformed_width, ...@@ -262,13 +274,11 @@ __device__ void get_transform_matrix(const int transformed_width,
} }
template <typename T> template <typename T>
__global__ void RoiTransformKernel(const float* input_data, __global__ void RoiTransformKernel(
const float* rois_data, const float* input_data, const float* rois_data, const int* roi2image_data,
const int* roi2image_data, int num_rois, int num_rois, int in_height, int in_width, int channels,
int in_height, int in_width, int channels, int transformed_height, int transformed_width, float spatial_scale,
int transformed_height, T* output_data, int* out2in_idx, T* out2in_w) {
int transformed_width, float spatial_scale,
T* output_data) {
int output_size = int output_size =
num_rois * transformed_height * transformed_width * channels; num_rois * transformed_height * transformed_width * channels;
...@@ -311,7 +321,8 @@ __global__ void RoiTransformKernel(const float* input_data, ...@@ -311,7 +321,8 @@ __global__ void RoiTransformKernel(const float* input_data,
// Perform bilinear interpolation // Perform bilinear interpolation
int in_n = roi2image_data[n]; int in_n = roi2image_data[n];
bilinear_interpolate<T>(input_data, channels, in_width, in_height, in_n, bilinear_interpolate<T>(input_data, channels, in_width, in_height, in_n,
c, in_w, in_h, output_data + index); c, in_w, in_h, output_data + index, index,
out2in_idx, out2in_w);
} }
} else { } else {
...@@ -328,6 +339,16 @@ class CUDAROIPerspectiveTransformOpKernel : public framework::OpKernel<T> { ...@@ -328,6 +339,16 @@ class CUDAROIPerspectiveTransformOpKernel : public framework::OpKernel<T> {
auto* in = ctx.Input<framework::Tensor>("X"); auto* in = ctx.Input<framework::Tensor>("X");
auto* rois = ctx.Input<framework::LoDTensor>("ROIs"); auto* rois = ctx.Input<framework::LoDTensor>("ROIs");
auto* out = ctx.Output<framework::Tensor>("Out"); auto* out = ctx.Output<framework::Tensor>("Out");
auto* out2in_idx = ctx.Output<framework::Tensor>("Out2InIdx");
auto* out2in_w = ctx.Output<framework::Tensor>("Out2InWeights");
int* out2in_idx_data =
out2in_idx->mutable_data<int>({out->numel(), 4}, ctx.GetPlace());
T* out2in_w_data =
out2in_w->mutable_data<T>({out->numel(), 4}, ctx.GetPlace());
math::SetConstant<platform::CUDADeviceContext, int> init;
init(ctx.cuda_device_context(), out2in_idx, static_cast<int>(-1));
auto transformed_height = ctx.Attr<int>("transformed_height"); auto transformed_height = ctx.Attr<int>("transformed_height");
auto transformed_width = ctx.Attr<int>("transformed_width"); auto transformed_width = ctx.Attr<int>("transformed_width");
...@@ -364,7 +385,7 @@ class CUDAROIPerspectiveTransformOpKernel : public framework::OpKernel<T> { ...@@ -364,7 +385,7 @@ class CUDAROIPerspectiveTransformOpKernel : public framework::OpKernel<T> {
RoiTransformKernel<T><<<grid, block, 0, stream>>>( RoiTransformKernel<T><<<grid, block, 0, stream>>>(
input_data, rois_data, roi2image_dev.data<int>(), rois_num, in_height, input_data, rois_data, roi2image_dev.data<int>(), rois_num, in_height,
in_width, channels, transformed_height, transformed_width, in_width, channels, transformed_height, transformed_width,
spatial_scale, output_data); spatial_scale, output_data, out2in_idx_data, out2in_w_data);
} }
}; };
...@@ -420,100 +441,42 @@ __device__ T get_feature_gradient(T xs, T ys, int w, int h, const int width, ...@@ -420,100 +441,42 @@ __device__ T get_feature_gradient(T xs, T ys, int w, int h, const int width,
} }
template <typename T> template <typename T>
__global__ void RoiTransformGradKernel( __global__ void RoiTransformGradKernel(int out_size, const int* out2in_idx_data,
const size_t* lod, const T* rois_data, int batch_size, int num_rois, const T* out2in_w_data,
int in_height, int in_width, int channels, int transformed_height, const T* out_grad_data,
int transformed_width, float spatial_scale, const T* out_grad_data,
T* in_grad_data) { T* in_grad_data) {
int input_size = batch_size * in_height * in_width * channels; CUDA_1D_KERNEL_LOOP(index, out_size * 4) {
int in_idx = out2in_idx_data[index];
CUDA_1D_KERNEL_LOOP(index, input_size) { if (in_idx >= 0) {
// (n, c, h, w) coords in input int out_idx = index / 4;
int in_w = idx4_4(index, batch_size, channels, in_height, in_width); atomicAdd(in_grad_data + in_idx,
int in_h = idx4_3(index, batch_size, channels, in_height, in_width); out_grad_data[out_idx] * out2in_w_data[index]);
int c = idx4_2(index, batch_size, channels, in_height, in_width);
int n = idx4_1(index, batch_size, channels, in_height, in_width);
T gradient = 0.0;
// Accumulate gradient over all RoIs that interpolated this element
for (size_t roi_idx = lod[n]; roi_idx < lod[n + 1]; ++roi_idx) {
const T* rois = rois_data + roi_idx * 8;
T roi_x[4];
T roi_y[4];
for (int k = 0; k < 4; ++k) {
roi_x[k] = rois[2 * k] * spatial_scale;
roi_y[k] = rois[2 * k + 1] * spatial_scale;
}
// Get transform matrix
T matrix[9];
get_transform_matrix<T>(transformed_width, transformed_height, roi_x,
roi_y, matrix);
const T* out_grad_ptr =
out_grad_data +
(roi_idx * channels + c) * transformed_height * transformed_width;
for (int out_h = 0; out_h < transformed_height; ++out_h) {
for (int out_w = 0; out_w < transformed_width; ++out_w) {
T src_w;
T src_h;
get_source_coords<T>(matrix, out_w, out_h, &src_w, &src_h);
if (in_quad<T>(src_w, src_h, roi_x, roi_y)) {
if (GT<T>(-0.5, src_w) ||
GT<T>(src_w, static_cast<T>(in_width - 0.5)) ||
GT<T>(-0.5, src_h) ||
GT<T>(src_h, static_cast<T>(in_height - 0.5))) {
continue;
}
T weight = get_feature_gradient<T>(src_w, src_h, in_w, in_h,
in_width, in_height);
gradient +=
out_grad_ptr[out_h * transformed_width + out_w] * weight;
} }
} }
}
}
in_grad_data[index] = gradient;
}
} }
template <typename T> template <typename T>
class CUDAROIPerspectiveTransformGradOpKernel : public framework::OpKernel<T> { class CUDAROIPerspectiveTransformGradOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X"); auto* out2in_idx = ctx.Input<framework::LoDTensor>("Out2InIdx");
auto* rois = ctx.Input<framework::LoDTensor>("ROIs"); auto* out2in_w = ctx.Input<framework::LoDTensor>("Out2InWeights");
auto* out_grad = auto* out_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out")); ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X")); auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto transformed_height = ctx.Attr<int>("transformed_height");
auto transformed_width = ctx.Attr<int>("transformed_width");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
auto in_dims = in->dims();
int batch_size = in_dims[0];
int channels = in_dims[1];
int in_height = in_dims[2];
int in_width = in_dims[3];
int rois_num = rois->dims()[0];
T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace()); T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
const T* out_grad_data = out_grad->data<T>(); const T* out_grad_data = out_grad->data<T>();
const T* rois_data = rois->data<T>(); const int* out2in_idx_data = out2in_idx->data<int>();
const T* out2in_w_data = out2in_w->data<T>();
auto lod = rois->lod().back();
auto lod_data = lod.CUDAData(ctx.GetPlace());
int in_size = in->numel(); int out_size = out_grad->numel();
auto stream = ctx.cuda_device_context().stream(); auto stream = ctx.cuda_device_context().stream();
int block = 512; int block = 512;
int grid = (in_size + block - 1) / block; int grid = (out_size * 4 + block - 1) / block;
RoiTransformGradKernel<T><<<grid, block, 0, stream>>>( RoiTransformGradKernel<T><<<grid, block, 0, stream>>>(
lod_data, rois_data, batch_size, rois_num, in_height, in_width, out_size, out2in_idx_data, out2in_w_data, out_grad_data, in_grad_data);
channels, transformed_height, transformed_width, spatial_scale,
out_grad_data, in_grad_data);
} }
}; };
......
...@@ -1827,11 +1827,17 @@ def roi_perspective_transform(input, ...@@ -1827,11 +1827,17 @@ def roi_perspective_transform(input,
helper = LayerHelper('roi_perspective_transform', **locals()) helper = LayerHelper('roi_perspective_transform', **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
out = helper.create_variable_for_type_inference(dtype) out = helper.create_variable_for_type_inference(dtype)
out2in_idx = helper.create_variable_for_type_inference(dtype="int32")
out2in_w = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type="roi_perspective_transform", type="roi_perspective_transform",
inputs={"X": input, inputs={"X": input,
"ROIs": rois}, "ROIs": rois},
outputs={"Out": out}, outputs={
"Out": out,
"Out2InIdx": out2in_idx,
"Out2InWeights": out2in_w
},
attrs={ attrs={
"transformed_height": transformed_height, "transformed_height": transformed_height,
"transformed_width": transformed_width, "transformed_width": transformed_width,
......
...@@ -299,6 +299,10 @@ class TestROIPoolOp(OpTest): ...@@ -299,6 +299,10 @@ class TestROIPoolOp(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.outputs['Out2InIdx'] = np.zeros(
[np.product(self.outputs['Out'].shape), 4]).astype("int32")
self.outputs['Out2InWeights'] = np.zeros(
[np.product(self.outputs['Out'].shape), 4]).astype("float32")
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册