From ec64f44f0eab44044c6f4c7c780c6eb9a090e871 Mon Sep 17 00:00:00 2001 From: LielinJiang <50691816+LielinJiang@users.noreply.github.com> Date: Mon, 26 Aug 2019 15:07:05 +0800 Subject: [PATCH] Make roi_perspective_transform op return mask and transform matrix,test=release/1.5 (#19391) * make_roi_perspective_transform_op_return_mask_and_matrix * make_roi_perspective_transform_op_return_mask_and_matrix --- paddle/fluid/API.spec | 2 +- .../detection/roi_perspective_transform_op.cc | 47 ++++++++++++++++--- .../detection/roi_perspective_transform_op.cu | 31 +++++++++--- python/paddle/fluid/layers/detection.py | 22 +++++++-- .../test_roi_perspective_transform_op.py | 22 ++++++--- 5 files changed, 97 insertions(+), 27 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 893a5130ffd..16c61a0e49b 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -353,7 +353,7 @@ paddle.fluid.layers.rpn_target_assign (ArgSpec(args=['bbox_pred', 'cls_logits', paddle.fluid.layers.retinanet_target_assign (ArgSpec(args=['bbox_pred', 'cls_logits', 'anchor_box', 'anchor_var', 'gt_boxes', 'gt_labels', 'is_crowd', 'im_info', 'num_classes', 'positive_overlap', 'negative_overlap'], varargs=None, keywords=None, defaults=(1, 0.5, 0.4)), ('document', 'fa1d1c9d5e0111684c0db705f86a2595')) paddle.fluid.layers.sigmoid_focal_loss (ArgSpec(args=['x', 'label', 'fg_num', 'gamma', 'alpha'], varargs=None, keywords=None, defaults=(2, 0.25)), ('document', 'aeac6aae100173b3fc7f102cf3023a3d')) paddle.fluid.layers.anchor_generator (ArgSpec(args=['input', 'anchor_sizes', 'aspect_ratios', 'variance', 'stride', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, [0.1, 0.1, 0.2, 0.2], None, 0.5, None)), ('document', '0aaacaf9858b8270a8ab5b0aacdd94b7')) -paddle.fluid.layers.roi_perspective_transform (ArgSpec(args=['input', 'rois', 'transformed_height', 'transformed_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1.0,)), ('document', 'd1ddc75629fedee46f82e631e22c79dc')) +paddle.fluid.layers.roi_perspective_transform (ArgSpec(args=['input', 'rois', 'transformed_height', 'transformed_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1.0,)), ('document', 'a82016342789ba9d85737e405f824ff1')) paddle.fluid.layers.generate_proposal_labels (ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random', 'is_cls_agnostic', 'is_cascade_rcnn'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True, False, False)), ('document', 'e87c1131e98715d3657a96c44db1b910')) paddle.fluid.layers.generate_proposals (ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None)), ('document', 'b7d707822b6af2a586bce608040235b1')) paddle.fluid.layers.generate_mask_labels (ArgSpec(args=['im_info', 'gt_classes', 'is_crowd', 'gt_segms', 'rois', 'labels_int32', 'num_classes', 'resolution'], varargs=None, keywords=None, defaults=None), ('document', 'b319b10ddaf17fb4ddf03518685a17ef')) diff --git a/paddle/fluid/operators/detection/roi_perspective_transform_op.cc b/paddle/fluid/operators/detection/roi_perspective_transform_op.cc index 54dd28c986f..6628dde5c2f 100644 --- a/paddle/fluid/operators/detection/roi_perspective_transform_op.cc +++ b/paddle/fluid/operators/detection/roi_perspective_transform_op.cc @@ -243,7 +243,9 @@ class CPUROIPerspectiveTransformOpKernel : public framework::OpKernel { auto* in = ctx.Input("X"); auto* rois = ctx.Input("ROIs"); auto* out = ctx.Output("Out"); - + auto* mask = ctx.Output("Mask"); + auto* out_transform_matrix = + ctx.Output("TransformMatrix"); auto transformed_height = ctx.Attr("transformed_height"); auto transformed_width = ctx.Attr("transformed_width"); auto spatial_scale = ctx.Attr("spatial_scale"); @@ -255,6 +257,7 @@ class CPUROIPerspectiveTransformOpKernel : public framework::OpKernel { int rois_num = rois->dims()[0]; const T* input_data = in->data(); + int* mask_data = mask->mutable_data(ctx.GetPlace()); framework::Tensor roi2image; roi2image.Resize({rois_num}); @@ -269,6 +272,9 @@ class CPUROIPerspectiveTransformOpKernel : public framework::OpKernel { T* output_data = out->mutable_data(ctx.GetPlace()); const T* rois_data = rois->data(); + T* transform_matrix = + out_transform_matrix->mutable_data({rois_num, 9}, ctx.GetPlace()); + for (int n = 0; n < rois_num; ++n) { const T* n_rois = rois_data + n * 8; T roi_x[4]; @@ -279,10 +285,12 @@ class CPUROIPerspectiveTransformOpKernel : public framework::OpKernel { } int image_id = roi2image_data[n]; // Get transform matrix - T transform_matrix[9]; + T matrix[9]; get_transform_matrix(transformed_width, transformed_height, roi_x, - roi_y, transform_matrix); - + roi_y, matrix); + for (int i = 0; i < 9; i++) { + transform_matrix[n * 9 + i] = matrix[i]; + } for (int c = 0; c < channels; ++c) { for (int out_h = 0; out_h < transformed_height; ++out_h) { for (int out_w = 0; out_w < transformed_width; ++out_w) { @@ -291,20 +299,26 @@ class CPUROIPerspectiveTransformOpKernel : public framework::OpKernel { c * transformed_height * transformed_width + out_h * transformed_width + out_w; T in_w, in_h; - get_source_coords(transform_matrix, out_w, out_h, &in_w, &in_h); + get_source_coords(matrix, out_w, out_h, &in_w, &in_h); if (in_quad(in_w, in_h, roi_x, roi_y)) { if (GT(-0.5, in_w) || GT(in_w, static_cast(in_width - 0.5)) || GT(-0.5, in_h) || GT(in_h, static_cast(in_height - 0.5))) { output_data[out_index] = 0.0; + mask_data[(n * transformed_height + out_h) * transformed_width + + out_w] = 0; } else { bilinear_interpolate(input_data, channels, in_width, in_height, image_id, c, in_w, in_h, output_data + out_index); + mask_data[(n * transformed_height + out_h) * transformed_width + + out_w] = 1; } } else { output_data[out_index] = 0.0; + mask_data[(n * transformed_height + out_h) * transformed_width + + out_w] = 0; } } } @@ -467,7 +481,6 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel { "Output(Out) of ROIPerspectiveTransformOp should not be null."); auto input_dims = ctx->GetInputDim("X"); auto rois_dims = ctx->GetInputDim("ROIs"); - PADDLE_ENFORCE(input_dims.size() == 4, "The format of input tensor is NCHW."); PADDLE_ENFORCE(rois_dims.size() == 2, @@ -476,7 +489,6 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(rois_dims[1] == 8, "ROIs should be a 2-D LoDTensor of shape (num_rois, 8)" "given as [[x0, y0, x1, y1, x2, y2, x3, y3], ...]."); - int transformed_height = ctx->Attrs().Get("transformed_height"); int transformed_width = ctx->Attrs().Get("transformed_width"); float spatial_scale = ctx->Attrs().Get("spatial_scale"); @@ -493,7 +505,18 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel { static_cast(transformed_width)}); auto out_dims = framework::make_ddim(out_dims_v); + std::vector mask_dims_v({rois_dims[0], // num_rois + 1, // channels + static_cast(transformed_height), + static_cast(transformed_width)}); + auto mask_dims = framework::make_ddim(mask_dims_v); + + std::vector matrix_dims_v({rois_dims[0], 9}); + auto matrix_dims = framework::make_ddim(matrix_dims_v); + ctx->SetOutputDim("Out", out_dims); + ctx->SetOutputDim("Mask", mask_dims); + ctx->SetOutputDim("TransformMatrix", matrix_dims); ctx->SetOutputDim("Out2InIdx", out_dims); ctx->SetOutputDim("Out2InWeights", out_dims); ctx->ShareLoD("ROIs", /*->*/ "Out"); @@ -552,6 +575,16 @@ class ROIPerspectiveTransformOpMaker "(Tensor), " "The output of ROIPerspectiveTransformOp is a 4-D tensor with shape " "(num_rois, channels, transformed_h, transformed_w)."); + AddOutput("Mask", + "(Tensor), " + "The output mask of ROIPerspectiveTransformOp is a 4-D tensor " + "with shape " + "(num_rois, 1, transformed_h, transformed_w)."); + AddOutput("TransformMatrix", + "(Tensor), " + "The output transform matrix of ROIPerspectiveTransformOp is a " + "1-D tensor with shape " + "(num_rois, 9)."); AddOutput("Out2InIdx", "(Tensor), " "An intermediate tensor used to map indexes of input feature map " diff --git a/paddle/fluid/operators/detection/roi_perspective_transform_op.cu b/paddle/fluid/operators/detection/roi_perspective_transform_op.cu index 85eb0c45e06..19df68faf9e 100644 --- a/paddle/fluid/operators/detection/roi_perspective_transform_op.cu +++ b/paddle/fluid/operators/detection/roi_perspective_transform_op.cu @@ -274,11 +274,14 @@ __device__ void get_transform_matrix(const int transformed_width, } template -__global__ void RoiTransformKernel( - const float* input_data, const float* rois_data, const int* roi2image_data, - int num_rois, int in_height, int in_width, int channels, - int transformed_height, int transformed_width, float spatial_scale, - T* output_data, int* out2in_idx, T* out2in_w) { +__global__ void RoiTransformKernel(const float* input_data, + const float* rois_data, + const int* roi2image_data, int num_rois, + int in_height, int in_width, int channels, + int transformed_height, + int transformed_width, float spatial_scale, + T* output_data, int* out2in_idx, T* out2in_w, + int* mask, T* transform_matrix) { int output_size = num_rois * transformed_height * transformed_width * channels; @@ -306,7 +309,9 @@ __global__ void RoiTransformKernel( T matrix[9]; get_transform_matrix(transformed_width, transformed_height, roi_x, roi_y, matrix); - + for (int i = 0; i < 9; i++) { + transform_matrix[n * 9 + i] = matrix[i]; + } // Get source coords T in_w; T in_h; @@ -317,17 +322,20 @@ __global__ void RoiTransformKernel( GT(-0.5, in_h) || GT(in_h, static_cast(in_height - 0.5))) { // Skip if source coords is not in input image output_data[index] = 0.0; + mask[(n * transformed_height + out_h) * transformed_width + out_w] = 0; } else { // Perform bilinear interpolation int in_n = roi2image_data[n]; bilinear_interpolate(input_data, channels, in_width, in_height, in_n, c, in_w, in_h, output_data + index, index, out2in_idx, out2in_w); + mask[(n * transformed_height + out_h) * transformed_width + out_w] = 1; } } else { // Skip if source coords is not in quad output_data[index] = 0.0; + mask[(n * transformed_height + out_h) * transformed_width + out_w] = 0; } } } @@ -341,7 +349,11 @@ class CUDAROIPerspectiveTransformOpKernel : public framework::OpKernel { auto* out = ctx.Output("Out"); auto* out2in_idx = ctx.Output("Out2InIdx"); auto* out2in_w = ctx.Output("Out2InWeights"); + auto* mask = ctx.Output("Mask"); + auto* out_transform_matrix = + ctx.Output("TransformMatrix"); + int* mask_data = mask->mutable_data(ctx.GetPlace()); int* out2in_idx_data = out2in_idx->mutable_data({out->numel(), 4}, ctx.GetPlace()); T* out2in_w_data = @@ -382,10 +394,15 @@ class CUDAROIPerspectiveTransformOpKernel : public framework::OpKernel { int block = 512; int grid = (out_size + block - 1) / block; + // Get transform matrix + T* matrix = + out_transform_matrix->mutable_data({rois_num, 9}, ctx.GetPlace()); + RoiTransformKernel<<>>( input_data, rois_data, roi2image_dev.data(), rois_num, in_height, in_width, channels, transformed_height, transformed_width, - spatial_scale, output_data, out2in_idx_data, out2in_w_data); + spatial_scale, output_data, out2in_idx_data, out2in_w_data, mask_data, + matrix); } }; diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 8d8b368caf2..6bb71e8991b 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -2100,8 +2100,16 @@ def roi_perspective_transform(input, spatial_scale (float): Spatial scale factor to scale ROI coords. Default: 1.0 Returns: - Variable: The output of ROIPerspectiveTransformOp which is a 4-D tensor with shape - (num_rois, channels, transformed_h, transformed_w). + tuple: A tuple with three Variables. (out, mask, transform_matrix) + + out: The output of ROIPerspectiveTransformOp which is a 4-D tensor with shape + (num_rois, channels, transformed_h, transformed_w). + + mask: The mask of ROIPerspectiveTransformOp which is a 4-D tensor with shape + (num_rois, 1, transformed_h, transformed_w). + + transform_matrix: The transform matrix of ROIPerspectiveTransformOp which is + a 2-D tensor with shape (num_rois, 9). Examples: .. code-block:: python @@ -2110,11 +2118,13 @@ def roi_perspective_transform(input, x = fluid.layers.data(name='x', shape=[256, 28, 28], dtype='float32') rois = fluid.layers.data(name='rois', shape=[8], lod_level=1, dtype='float32') - out = fluid.layers.roi_perspective_transform(x, rois, 7, 7, 1.0) + out, mask, transform_matrix = fluid.layers.roi_perspective_transform(x, rois, 7, 7, 1.0) """ helper = LayerHelper('roi_perspective_transform', **locals()) dtype = helper.input_dtype() out = helper.create_variable_for_type_inference(dtype) + mask = helper.create_variable_for_type_inference(dtype="int32") + transform_matrix = helper.create_variable_for_type_inference(dtype) out2in_idx = helper.create_variable_for_type_inference(dtype="int32") out2in_w = helper.create_variable_for_type_inference(dtype) helper.append_op( @@ -2124,14 +2134,16 @@ def roi_perspective_transform(input, outputs={ "Out": out, "Out2InIdx": out2in_idx, - "Out2InWeights": out2in_w + "Out2InWeights": out2in_w, + "Mask": mask, + "TransformMatrix": transform_matrix }, attrs={ "transformed_height": transformed_height, "transformed_width": transformed_width, "spatial_scale": spatial_scale }) - return out + return out, mask, transform_matrix def generate_proposal_labels(rpn_rois, diff --git a/python/paddle/fluid/tests/unittests/test_roi_perspective_transform_op.py b/python/paddle/fluid/tests/unittests/test_roi_perspective_transform_op.py index 90c5e210a25..b56f331e908 100644 --- a/python/paddle/fluid/tests/unittests/test_roi_perspective_transform_op.py +++ b/python/paddle/fluid/tests/unittests/test_roi_perspective_transform_op.py @@ -198,7 +198,9 @@ def roi_transform(in_data, rois, rois_lod, transformed_height, roi2image[j] = i out = np.zeros([rois_num, channels, transformed_height, transformed_width]) - + mask = np.zeros( + [rois_num, 1, transformed_height, transformed_width]).astype('int') + matrix = np.zeros([rois_num, 9], dtype=in_data.dtype) for n in range(rois_num): roi_x = [] roi_y = [] @@ -208,7 +210,7 @@ def roi_transform(in_data, rois, rois_lod, transformed_height, image_id = roi2image[n] transform_matrix = get_transform_matrix( transformed_width, transformed_height, roi_x, roi_y) - + matrix[n] = transform_matrix for c in range(channels): for out_h in range(transformed_height): for out_w in range(transformed_width): @@ -219,9 +221,11 @@ def roi_transform(in_data, rois, rois_lod, transformed_height, in_h, -0.5) and lt_e(in_h, in_height - 0.5): out[n][c][out_h][out_w] = bilinear_interpolate( in_data, image_id, c, in_w, in_h) + mask[n][0][out_h][out_w] = 1 else: out[n][c][out_h][out_w] = 0.0 - return out.astype("float32") + mask[n][0][out_h][out_w] = 0 + return out.astype("float32"), mask, matrix class TestROIPoolOp(OpTest): @@ -236,10 +240,14 @@ class TestROIPoolOp(OpTest): 'transformed_height': self.transformed_height, 'transformed_width': self.transformed_width } - out = roi_transform(self.x, self.rois, self.rois_lod, - self.transformed_height, self.transformed_width, - self.spatial_scale) - self.outputs = {'Out': out} + out, mask, transform_matrix = roi_transform( + self.x, self.rois, self.rois_lod, self.transformed_height, + self.transformed_width, self.spatial_scale) + self.outputs = { + 'Out': out, + 'Mask': mask, + 'TransformMatrix': transform_matrix + } def init_test_case(self): self.batch_size = 2 -- GitLab