提交 449c7a9f 编写于 作者: L LielinJiang 提交者: whs

Make roi_perspective_transform op return mask and transform matrix (#18371)

* modify roi_perspective_transform_op to output mask and transform matrix

* modify comment

* modify comment

* modify API.spec

* update API.spec

* remove no use header, test=develop

* resolve conflict
上级 99659a96
......@@ -353,7 +353,7 @@ paddle.fluid.layers.rpn_target_assign (ArgSpec(args=['bbox_pred', 'cls_logits',
paddle.fluid.layers.retinanet_target_assign (ArgSpec(args=['bbox_pred', 'cls_logits', 'anchor_box', 'anchor_var', 'gt_boxes', 'gt_labels', 'is_crowd', 'im_info', 'num_classes', 'positive_overlap', 'negative_overlap'], varargs=None, keywords=None, defaults=(1, 0.5, 0.4)), ('document', 'fa1d1c9d5e0111684c0db705f86a2595'))
paddle.fluid.layers.sigmoid_focal_loss (ArgSpec(args=['x', 'label', 'fg_num', 'gamma', 'alpha'], varargs=None, keywords=None, defaults=(2, 0.25)), ('document', 'aeac6aae100173b3fc7f102cf3023a3d'))
paddle.fluid.layers.anchor_generator (ArgSpec(args=['input', 'anchor_sizes', 'aspect_ratios', 'variance', 'stride', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, [0.1, 0.1, 0.2, 0.2], None, 0.5, None)), ('document', '0aaacaf9858b8270a8ab5b0aacdd94b7'))
paddle.fluid.layers.roi_perspective_transform (ArgSpec(args=['input', 'rois', 'transformed_height', 'transformed_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1.0,)), ('document', 'd1ddc75629fedee46f82e631e22c79dc'))
paddle.fluid.layers.roi_perspective_transform (ArgSpec(args=['input', 'rois', 'transformed_height', 'transformed_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1.0,)), ('document', '54e3bf70e3bdbd58b3b9b65b3c69a854'))
paddle.fluid.layers.generate_proposal_labels (ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random', 'is_cls_agnostic', 'is_cascade_rcnn'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True, False, False)), ('document', '69def376b42ef0681d0cc7f53a2dac4b'))
paddle.fluid.layers.generate_proposals (ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None)), ('document', 'b7d707822b6af2a586bce608040235b1'))
paddle.fluid.layers.generate_mask_labels (ArgSpec(args=['im_info', 'gt_classes', 'is_crowd', 'gt_segms', 'rois', 'labels_int32', 'num_classes', 'resolution'], varargs=None, keywords=None, defaults=None), ('document', 'b319b10ddaf17fb4ddf03518685a17ef'))
......
......@@ -243,7 +243,9 @@ class CPUROIPerspectiveTransformOpKernel : public framework::OpKernel<T> {
auto* in = ctx.Input<framework::Tensor>("X");
auto* rois = ctx.Input<framework::LoDTensor>("ROIs");
auto* out = ctx.Output<framework::Tensor>("Out");
auto* mask = ctx.Output<framework::Tensor>("Mask");
auto* out_transform_matrix =
ctx.Output<framework::Tensor>("TransformMatrix");
auto transformed_height = ctx.Attr<int>("transformed_height");
auto transformed_width = ctx.Attr<int>("transformed_width");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
......@@ -255,6 +257,7 @@ class CPUROIPerspectiveTransformOpKernel : public framework::OpKernel<T> {
int rois_num = rois->dims()[0];
const T* input_data = in->data<T>();
int* mask_data = mask->mutable_data<int>(ctx.GetPlace());
framework::Tensor roi2image;
roi2image.Resize({rois_num});
......@@ -279,7 +282,8 @@ class CPUROIPerspectiveTransformOpKernel : public framework::OpKernel<T> {
}
int image_id = roi2image_data[n];
// Get transform matrix
T transform_matrix[9];
T* transform_matrix =
out_transform_matrix->mutable_data<T>({9}, ctx.GetPlace());
get_transform_matrix<T>(transformed_width, transformed_height, roi_x,
roi_y, transform_matrix);
......@@ -298,13 +302,19 @@ class CPUROIPerspectiveTransformOpKernel : public framework::OpKernel<T> {
GT<T>(-0.5, in_h) ||
GT<T>(in_h, static_cast<T>(in_height - 0.5))) {
output_data[out_index] = 0.0;
mask_data[(n * transformed_height + out_h) * transformed_width +
out_w] = 0;
} else {
bilinear_interpolate(input_data, channels, in_width, in_height,
image_id, c, in_w, in_h,
output_data + out_index);
mask_data[(n * transformed_height + out_h) * transformed_width +
out_w] = 1;
}
} else {
output_data[out_index] = 0.0;
mask_data[(n * transformed_height + out_h) * transformed_width +
out_w] = 0;
}
}
}
......@@ -467,7 +477,6 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel {
"Output(Out) of ROIPerspectiveTransformOp should not be null.");
auto input_dims = ctx->GetInputDim("X");
auto rois_dims = ctx->GetInputDim("ROIs");
PADDLE_ENFORCE(input_dims.size() == 4,
"The format of input tensor is NCHW.");
PADDLE_ENFORCE(rois_dims.size() == 2,
......@@ -476,7 +485,6 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(rois_dims[1] == 8,
"ROIs should be a 2-D LoDTensor of shape (num_rois, 8)"
"given as [[x0, y0, x1, y1, x2, y2, x3, y3], ...].");
int transformed_height = ctx->Attrs().Get<int>("transformed_height");
int transformed_width = ctx->Attrs().Get<int>("transformed_width");
float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");
......@@ -493,7 +501,18 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel {
static_cast<int64_t>(transformed_width)});
auto out_dims = framework::make_ddim(out_dims_v);
std::vector<int64_t> mask_dims_v({rois_dims[0], // num_rois
1, // channels
static_cast<int64_t>(transformed_height),
static_cast<int64_t>(transformed_width)});
auto mask_dims = framework::make_ddim(mask_dims_v);
std::vector<int64_t> matrix_dims_v(9);
auto matrix_dims = framework::make_ddim(matrix_dims_v);
ctx->SetOutputDim("Out", out_dims);
ctx->SetOutputDim("Mask", mask_dims);
ctx->SetOutputDim("TransformMatrix", matrix_dims);
ctx->SetOutputDim("Out2InIdx", out_dims);
ctx->SetOutputDim("Out2InWeights", out_dims);
ctx->ShareLoD("ROIs", /*->*/ "Out");
......@@ -552,6 +571,16 @@ class ROIPerspectiveTransformOpMaker
"(Tensor), "
"The output of ROIPerspectiveTransformOp is a 4-D tensor with shape "
"(num_rois, channels, transformed_h, transformed_w).");
AddOutput("Mask",
"(Tensor), "
"The output mask of ROIPerspectiveTransformOp is a 4-D tensor "
"with shape "
"(num_rois, 1, transformed_h, transformed_w).");
AddOutput("TransformMatrix",
"(Tensor), "
"The output transform matrix of ROIPerspectiveTransformOp is a "
"1-D tensor with shape "
"(9,).");
AddOutput("Out2InIdx",
"(Tensor), "
"An intermediate tensor used to map indexes of input feature map "
......
......@@ -278,7 +278,7 @@ __global__ void RoiTransformKernel(
const float* input_data, const float* rois_data, const int* roi2image_data,
int num_rois, int in_height, int in_width, int channels,
int transformed_height, int transformed_width, float spatial_scale,
T* output_data, int* out2in_idx, T* out2in_w) {
T* output_data, int* out2in_idx, T* out2in_w, int* mask, T* matrix) {
int output_size =
num_rois * transformed_height * transformed_width * channels;
......@@ -303,7 +303,6 @@ __global__ void RoiTransformKernel(
}
// Get transform matrix
T matrix[9];
get_transform_matrix<T>(transformed_width, transformed_height, roi_x, roi_y,
matrix);
......@@ -317,17 +316,20 @@ __global__ void RoiTransformKernel(
GT<T>(-0.5, in_h) || GT<T>(in_h, static_cast<T>(in_height - 0.5))) {
// Skip if source coords is not in input image
output_data[index] = 0.0;
mask[(n * transformed_height + out_h) * transformed_width + out_w] = 0;
} else {
// Perform bilinear interpolation
int in_n = roi2image_data[n];
bilinear_interpolate<T>(input_data, channels, in_width, in_height, in_n,
c, in_w, in_h, output_data + index, index,
out2in_idx, out2in_w);
mask[(n * transformed_height + out_h) * transformed_width + out_w] = 1;
}
} else {
// Skip if source coords is not in quad
output_data[index] = 0.0;
mask[(n * transformed_height + out_h) * transformed_width + out_w] = 0;
}
}
}
......@@ -341,7 +343,11 @@ class CUDAROIPerspectiveTransformOpKernel : public framework::OpKernel<T> {
auto* out = ctx.Output<framework::Tensor>("Out");
auto* out2in_idx = ctx.Output<framework::Tensor>("Out2InIdx");
auto* out2in_w = ctx.Output<framework::Tensor>("Out2InWeights");
auto* mask = ctx.Output<framework::Tensor>("Mask");
auto* out_transform_matrix =
ctx.Output<framework::Tensor>("TransformMatrix");
int* mask_data = mask->mutable_data<int>(ctx.GetPlace());
int* out2in_idx_data =
out2in_idx->mutable_data<int>({out->numel(), 4}, ctx.GetPlace());
T* out2in_w_data =
......@@ -382,10 +388,14 @@ class CUDAROIPerspectiveTransformOpKernel : public framework::OpKernel<T> {
int block = 512;
int grid = (out_size + block - 1) / block;
// Get transform matrix
T* matrix = out_transform_matrix->mutable_data<T>({9}, ctx.GetPlace());
RoiTransformKernel<T><<<grid, block, 0, stream>>>(
input_data, rois_data, roi2image_dev.data<int>(), rois_num, in_height,
in_width, channels, transformed_height, transformed_width,
spatial_scale, output_data, out2in_idx_data, out2in_w_data);
spatial_scale, output_data, out2in_idx_data, out2in_w_data, mask_data,
matrix);
}
};
......
......@@ -2099,8 +2099,16 @@ def roi_perspective_transform(input,
spatial_scale (float): Spatial scale factor to scale ROI coords. Default: 1.0
Returns:
Variable: The output of ROIPerspectiveTransformOp which is a 4-D tensor with shape
(num_rois, channels, transformed_h, transformed_w).
tuple: A tuple with three Variables. (out, mask, transform_matrix)
out: The output of ROIPerspectiveTransformOp which is a 4-D tensor with shape
(num_rois, channels, transformed_h, transformed_w).
mask: The mask of ROIPerspectiveTransformOp which is a 4-D tensor with shape
(num_rois, 1, transformed_h, transformed_w).
transform_matrix: The transform matrix of ROIPerspectiveTransformOp which is
a 1-D tensor with shape (9,).
Examples:
.. code-block:: python
......@@ -2109,11 +2117,13 @@ def roi_perspective_transform(input,
x = fluid.layers.data(name='x', shape=[256, 28, 28], dtype='float32')
rois = fluid.layers.data(name='rois', shape=[8], lod_level=1, dtype='float32')
out = fluid.layers.roi_perspective_transform(x, rois, 7, 7, 1.0)
out, mask, transform_matrix = fluid.layers.roi_perspective_transform(x, rois, 7, 7, 1.0)
"""
helper = LayerHelper('roi_perspective_transform', **locals())
dtype = helper.input_dtype()
out = helper.create_variable_for_type_inference(dtype)
mask = helper.create_variable_for_type_inference(dtype="int32")
transform_matrix = helper.create_variable_for_type_inference(dtype)
out2in_idx = helper.create_variable_for_type_inference(dtype="int32")
out2in_w = helper.create_variable_for_type_inference(dtype)
helper.append_op(
......@@ -2123,14 +2133,16 @@ def roi_perspective_transform(input,
outputs={
"Out": out,
"Out2InIdx": out2in_idx,
"Out2InWeights": out2in_w
"Out2InWeights": out2in_w,
"Mask": mask,
"TransformMatrix": transform_matrix
},
attrs={
"transformed_height": transformed_height,
"transformed_width": transformed_width,
"spatial_scale": spatial_scale
})
return out
return out, mask, transform_matrix
def generate_proposal_labels(rpn_rois,
......
......@@ -198,7 +198,8 @@ def roi_transform(in_data, rois, rois_lod, transformed_height,
roi2image[j] = i
out = np.zeros([rois_num, channels, transformed_height, transformed_width])
mask = np.zeros(
[rois_num, 1, transformed_height, transformed_width]).astype('int')
for n in range(rois_num):
roi_x = []
roi_y = []
......@@ -219,9 +220,11 @@ def roi_transform(in_data, rois, rois_lod, transformed_height,
in_h, -0.5) and lt_e(in_h, in_height - 0.5):
out[n][c][out_h][out_w] = bilinear_interpolate(
in_data, image_id, c, in_w, in_h)
mask[n][0][out_h][out_w] = 1
else:
out[n][c][out_h][out_w] = 0.0
return out.astype("float32")
mask[n][0][out_h][out_w] = 0
return out.astype("float32"), mask, transform_matrix
class TestROIPoolOp(OpTest):
......@@ -236,10 +239,14 @@ class TestROIPoolOp(OpTest):
'transformed_height': self.transformed_height,
'transformed_width': self.transformed_width
}
out = roi_transform(self.x, self.rois, self.rois_lod,
self.transformed_height, self.transformed_width,
self.spatial_scale)
self.outputs = {'Out': out}
out, mask, transform_matrix = roi_transform(
self.x, self.rois, self.rois_lod, self.transformed_height,
self.transformed_width, self.spatial_scale)
self.outputs = {
'Out': out,
'Mask': mask,
'TransformMatrix': transform_matrix
}
def init_test_case(self):
self.batch_size = 2
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册