From ff22a9c4686908212abc4e505ad8336b496189f4 Mon Sep 17 00:00:00 2001 From: Leo Guo <58431564+ZibinGuo@users.noreply.github.com> Date: Thu, 2 Jun 2022 22:04:41 +0800 Subject: [PATCH] Add generate_proposals_v2 op and expend function of gather op for kunlun. *test=kunlun (#43162) * Add generate_proposals_v2 op and unittest for kunlun. *test=kunlun * Add the assign op to xpu2_op_list and expand the function of gather op. Add the unit-test of generate_proposals_v2. *test=kunlun --- .../fluid/operators/detection/CMakeLists.txt | 6 +- .../detection/generate_proposals_v2_op_xpu.cc | 370 ++++++++++++ paddle/fluid/operators/gather_op_xpu.cc | 37 +- .../fluid/platform/device/xpu/xpu2_op_list.h | 7 + .../xpu/test_generate_proposals_v2_op_xpu.py | 544 ++++++++++++++++++ 5 files changed, 955 insertions(+), 9 deletions(-) create mode 100644 paddle/fluid/operators/detection/generate_proposals_v2_op_xpu.cc create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_generate_proposals_v2_op_xpu.py diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index f10c8019199..99a69007aa5 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -39,12 +39,14 @@ endif() if(WITH_XPU) detection_library(iou_similarity_op SRCS iou_similarity_op.cc iou_similarity_op_xpu.cc) detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op_xpu.cc) + detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc generate_proposals_v2_op_xpu.cc) elseif(WITH_ASCEND_CL) detection_library(iou_similarity_op SRCS iou_similarity_op.cc iou_similarity_op_npu.cc) detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu prior_box_op_npu.cc) else() detection_library(iou_similarity_op SRCS iou_similarity_op.cc iou_similarity_op.cu) detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu) + # detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc) endif() detection_library(bipartite_match_op SRCS bipartite_match_op.cc) @@ -81,7 +83,9 @@ if(WITH_GPU OR WITH_ROCM) detection_library(collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc collect_fpn_proposals_op.cu DEPS ${TMPDEPS}) else() detection_library(generate_proposals_op SRCS generate_proposals_op.cc) - detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc) + if(NOT WITH_XPU) + detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc) + endif() detection_library(distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc) detection_library(collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc) endif() diff --git a/paddle/fluid/operators/detection/generate_proposals_v2_op_xpu.cc b/paddle/fluid/operators/detection/generate_proposals_v2_op_xpu.cc new file mode 100644 index 00000000000..28c94668ba7 --- /dev/null +++ b/paddle/fluid/operators/detection/generate_proposals_v2_op_xpu.cc @@ -0,0 +1,370 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_XPU + +#include +#include +#include +#include +#include "paddle/fluid/framework/mixed_vector.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/memory/memory.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +namespace { +template +static void SortDescending(const platform::XPUDeviceContext &dev_ctx, + const Tensor &value, Tensor *index_out, + int pre_nms_top_n) { + auto *value_data = value.data(); + auto place = dev_ctx.GetPlace(); + auto cpu_place = platform::CPUPlace(); + + Tensor scores_slice_cpu; + scores_slice_cpu.Resize({value.numel()}); + auto *scores_slice_cpu_data = scores_slice_cpu.mutable_data(cpu_place); + + memory::Copy(cpu_place, scores_slice_cpu_data, place, value_data, + sizeof(T) * value.numel()); + + // Sort index + Tensor index_t; + int *index = index_t.mutable_data({value.numel()}, cpu_place); + for (int i = 0; i < value.numel(); ++i) { + index[i] = i; + } + auto compare = [scores_slice_cpu_data](const int64_t &i, const int64_t &j) { + return scores_slice_cpu_data[i] > scores_slice_cpu_data[j]; + }; + + if (pre_nms_top_n <= 0 || pre_nms_top_n >= value.numel()) { + std::sort(index, index + value.numel(), compare); + } else { + std::nth_element(index, index + pre_nms_top_n, index + value.numel(), + compare); + std::sort(index, index + pre_nms_top_n, compare); + index_t.Resize({pre_nms_top_n}); + } + + int *idx_out = + index_out->mutable_data({index_t.numel()}, dev_ctx.GetPlace()); + memory::Copy(place, idx_out, cpu_place, index, sizeof(T) * index_t.numel()); +} + +template +static std::pair ProposalForOneImage( + const platform::XPUDeviceContext &dev_ctx, const Tensor &im_shape, + const Tensor &anchors, const Tensor &variances, + const Tensor &bbox_deltas, // [M, 4] + const Tensor &scores, // [N, 1] + int pre_nms_top_n, int post_nms_top_n, float nms_thresh, float min_size, + float eta, bool pixel_offset) { + // 1. pre nms + Tensor index_sort; + SortDescending(dev_ctx, scores, &index_sort, pre_nms_top_n); + + Tensor scores_sel, bbox_sel, anchor_sel, var_sel; + scores_sel.mutable_data({index_sort.numel(), 1}, dev_ctx.GetPlace()); + bbox_sel.mutable_data({index_sort.numel(), 4}, dev_ctx.GetPlace()); + anchor_sel.mutable_data({index_sort.numel(), 4}, dev_ctx.GetPlace()); + var_sel.mutable_data({index_sort.numel(), 4}, dev_ctx.GetPlace()); + + int r = xpu::gather(dev_ctx.x_context(), scores.data(), + index_sort.data(), scores_sel.data(), + {static_cast(scores.numel()), 1}, + index_sort.numel(), 0); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External("XPU API(gather) return " + "wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + + r = xpu::gather(dev_ctx.x_context(), bbox_deltas.data(), + index_sort.data(), bbox_sel.data(), + {static_cast(bbox_deltas.numel()) / 4, 4}, + index_sort.numel(), 0); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External("XPU API(gather) return " + "wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + + r = xpu::gather(dev_ctx.x_context(), anchors.data(), + index_sort.data(), anchor_sel.data(), + {static_cast(anchors.numel()) / 4, 4}, + index_sort.numel(), 0); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External("XPU API(gather) return " + "wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + + r = xpu::gather(dev_ctx.x_context(), variances.data(), + index_sort.data(), var_sel.data(), + {static_cast(variances.numel()) / 4, 4}, + index_sort.numel(), 0); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External("XPU API(gather) return " + "wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + + int num = scores.numel(); + int pre_nms_num = (pre_nms_top_n <= 0 || pre_nms_top_n > num) ? scores.numel() + : pre_nms_top_n; + scores_sel.Resize({pre_nms_num, 1}); + index_sort.Resize({pre_nms_num, 1}); + + // 2. box decode and clipping + Tensor proposals; + proposals.mutable_data({pre_nms_num, 4}, dev_ctx.GetPlace()); + + r = xpu::box_decoder(dev_ctx.x_context(), anchor_sel.data(), + var_sel.data(), bbox_sel.data(), + proposals.data(), pre_nms_num, !pixel_offset, true, + im_shape.data()); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External("XPU API(box_decoder) return " + "wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + + // 3. filter + Tensor keep_index, keep_num_t; + keep_index.mutable_data({pre_nms_num}, dev_ctx.GetPlace()); + keep_num_t.mutable_data({1}, dev_ctx.GetPlace()); + min_size = std::max(min_size, 1.0f); + r = xpu::remove_small_boxes(dev_ctx.x_context(), proposals.data(), + im_shape.data(), keep_index.data(), + keep_num_t.data(), pre_nms_num, min_size, + false, pixel_offset); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( + "XPU API(remove_small_boxes) return " + "wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + int keep_num; + const auto xpu_place = dev_ctx.GetPlace(); + memory::Copy(platform::CPUPlace(), &keep_num, xpu_place, + keep_num_t.data(), sizeof(int)); + keep_index.Resize({keep_num}); + + Tensor scores_filter, proposals_filter; + // Handle the case when there is no keep index left + if (keep_num == 0) { + phi::funcs::SetConstant set_zero; + proposals_filter.mutable_data({1, 4}, dev_ctx.GetPlace()); + scores_filter.mutable_data({1, 1}, dev_ctx.GetPlace()); + set_zero(dev_ctx, &proposals_filter, static_cast(0)); + set_zero(dev_ctx, &scores_filter, static_cast(0)); + return std::make_pair(proposals_filter, scores_filter); + } + proposals_filter.mutable_data({keep_num, 4}, dev_ctx.GetPlace()); + scores_filter.mutable_data({keep_num, 1}, dev_ctx.GetPlace()); + r = xpu::gather(dev_ctx.x_context(), proposals.data(), + keep_index.data(), proposals_filter.data(), + {pre_nms_num, 4}, keep_num, 0); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External("XPU API(gather) return " + "wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + + r = xpu::gather(dev_ctx.x_context(), scores_sel.data(), + keep_index.data(), scores_filter.data(), + {pre_nms_num, 1}, keep_num, 0); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External("XPU API(gather) return " + "wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + + if (nms_thresh <= 0) { + if (dev_ctx.x_context()->xpu_stream) { + dev_ctx.Wait(); + } + return std::make_pair(proposals_filter, scores_filter); + } + + // 4. nms + int nms_keep_num = 0; + r = xpu::nms(dev_ctx.x_context(), proposals_filter.data(), nullptr, + keep_index.data(), 1, 1, keep_num, -1, nms_thresh, -1, 0, + &nms_keep_num, pixel_offset); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External("XPU API(nms) return the" + "wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + if (post_nms_top_n > 0 && post_nms_top_n < nms_keep_num) { + keep_index.Resize({post_nms_top_n}); + } else { + keep_index.Resize({nms_keep_num}); + } + + Tensor scores_nms, proposals_nms; + proposals_nms.mutable_data({keep_index.numel(), 4}, dev_ctx.GetPlace()); + scores_nms.mutable_data({keep_index.numel(), 1}, dev_ctx.GetPlace()); + r = xpu::gather(dev_ctx.x_context(), proposals_filter.data(), + keep_index.data(), proposals_nms.data(), + {keep_num, 4}, keep_index.numel(), 0); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External("XPU API(gather) return " + "wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + r = xpu::gather(dev_ctx.x_context(), scores_filter.data(), + keep_index.data(), scores_nms.data(), + {keep_num, 1}, keep_index.numel(), 0); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External("XPU API(gather) return " + "wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + if (dev_ctx.x_context()->xpu_stream) { + dev_ctx.Wait(); + } + return std::make_pair(proposals_nms, scores_nms); +} +} // namespace + +template +class XPUGenerateProposalsV2Kernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *scores = context.Input("Scores"); + auto *bbox_deltas = context.Input("BboxDeltas"); + auto *im_shape = context.Input("ImShape"); + auto anchors = GET_DATA_SAFELY(context.Input("Anchors"), "Input", + "Anchors", "GenerateProposals"); + auto variances = GET_DATA_SAFELY(context.Input("Variances"), + "Input", "Variances", "GenerateProposals"); + + auto *rpn_rois = context.Output("RpnRois"); + auto *rpn_roi_probs = context.Output("RpnRoiProbs"); + + int pre_nms_top_n = context.Attr("pre_nms_topN"); + int post_nms_top_n = context.Attr("post_nms_topN"); + float nms_thresh = context.Attr("nms_thresh"); + float min_size = context.Attr("min_size"); + float eta = context.Attr("eta"); + bool pixel_offset = context.Attr("pixel_offset"); + PADDLE_ENFORCE_GE(eta, 1., + platform::errors::InvalidArgument( + "Not support adaptive NMS. The attribute 'eta' " + "should not less than 1. But received eta=[%d]", + eta)); + + auto &dev_ctx = context.template device_context(); + + auto scores_dim = scores->dims(); + // the shape of bbox score + int num = scores_dim[0]; + int c_score = scores_dim[1]; + int h_score = scores_dim[2]; + int w_score = scores_dim[3]; + + auto bbox_dim = bbox_deltas->dims(); + int c_bbox = bbox_dim[1]; + int h_bbox = bbox_dim[2]; + int w_bbox = bbox_dim[3]; + + Tensor bbox_deltas_swap, scores_swap; + bbox_deltas_swap.mutable_data({num, h_bbox, w_bbox, c_bbox}, + dev_ctx.GetPlace()); + scores_swap.mutable_data({num, h_score, w_score, c_score}, + dev_ctx.GetPlace()); + + std::vector axis = {0, 2, 3, 1}; + int r = xpu::transpose(dev_ctx.x_context(), bbox_deltas->data(), + bbox_deltas_swap.data(), + {num, c_bbox, h_bbox, w_bbox}, axis); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External("XPU API(transpose) return " + "wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + r = xpu::transpose(dev_ctx.x_context(), scores->data(), + scores_swap.data(), + {num, c_score, h_score, w_score}, axis); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External("XPU API(transpose) return " + "wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + + anchors.Resize({anchors.numel() / 4, 4}); + variances.Resize({variances.numel() / 4, 4}); + + // output + rpn_rois->mutable_data({bbox_deltas->numel() / 4, 4}, + context.GetPlace()); + rpn_roi_probs->mutable_data({scores->numel(), 1}, context.GetPlace()); + + T *rpn_rois_data = rpn_rois->data(); + T *rpn_roi_probs_data = rpn_roi_probs->data(); + + auto place = dev_ctx.GetPlace(); + auto cpu_place = platform::CPUPlace(); + + int num_proposals = 0; + std::vector offset(1, 0); + std::vector tmp_num; + + for (int64_t i = 0; i < num; ++i) { + Tensor im_shape_slice = im_shape->Slice(i, i + 1); + Tensor bbox_deltas_slice = bbox_deltas_swap.Slice(i, i + 1); + Tensor scores_slice = scores_swap.Slice(i, i + 1); + + bbox_deltas_slice.Resize({h_bbox * w_bbox * c_bbox / 4, 4}); + scores_slice.Resize({h_score * w_score * c_score, 1}); + + std::pair box_score_pair = ProposalForOneImage( + dev_ctx, im_shape_slice, anchors, variances, bbox_deltas_slice, + scores_slice, pre_nms_top_n, post_nms_top_n, nms_thresh, min_size, + eta, pixel_offset); + + Tensor &proposals = box_score_pair.first; + Tensor &scores = box_score_pair.second; + + memory::Copy(place, rpn_rois_data + num_proposals * 4, place, + proposals.data(), sizeof(T) * proposals.numel()); + memory::Copy(place, rpn_roi_probs_data + num_proposals, place, + scores.data(), sizeof(T) * scores.numel()); + if (dev_ctx.x_context()->xpu_stream) { + dev_ctx.Wait(); + } + num_proposals += proposals.dims()[0]; + offset.emplace_back(num_proposals); + tmp_num.push_back(proposals.dims()[0]); + } + if (context.HasOutput("RpnRoisNum")) { + auto *rpn_rois_num = context.Output("RpnRoisNum"); + rpn_rois_num->mutable_data({num}, context.GetPlace()); + int *num_data = rpn_rois_num->data(); + memory::Copy(place, num_data, cpu_place, &tmp_num[0], sizeof(int) * num); + rpn_rois_num->Resize({num}); + } + framework::LoD lod; + lod.emplace_back(offset); + rpn_rois->set_lod(lod); + rpn_roi_probs->set_lod(lod); + rpn_rois->Resize({num_proposals, 4}); + rpn_roi_probs->Resize({num_proposals, 1}); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_XPU_KERNEL(generate_proposals_v2, + ops::XPUGenerateProposalsV2Kernel< + paddle::platform::XPUDeviceContext, float>); + +#endif // PADDLE_WITH_XPU diff --git a/paddle/fluid/operators/gather_op_xpu.cc b/paddle/fluid/operators/gather_op_xpu.cc index 6c691aa14ae..9dd8f58d242 100644 --- a/paddle/fluid/operators/gather_op_xpu.cc +++ b/paddle/fluid/operators/gather_op_xpu.cc @@ -38,9 +38,20 @@ class GatherOpXPUKernel : public framework::OpKernel { auto *x = ctx.Input("X"); auto *index = ctx.Input("Index"); auto *output = ctx.Output("Out"); + + int axis = ctx.Attr("axis"); if (ctx.HasInput("Axis")) { - PADDLE_THROW(platform::errors::InvalidArgument( - "Now, it doesn't support XPU with Axis.")); + Tensor cpu_axis; + const Tensor *axis_tensor = ctx.Input("Axis"); + framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis); + const auto &axis_type = axis_tensor->dtype(); + if (framework::TransToProtoVarType(axis_type) == + framework::proto::VarType::INT32) { + axis = static_cast(cpu_axis.data()[0]); + } else if (framework::TransToProtoVarType(axis_type) == + framework::proto::VarType::INT64) { + axis = static_cast(cpu_axis.data()[0]); + } } output->mutable_data(ctx.GetPlace()); @@ -72,13 +83,13 @@ class GatherOpXPUKernel : public framework::OpKernel { r = xpu::gather( dev_ctx.x_context(), reinterpret_cast(x->data()), index->data(), reinterpret_cast(output->data()), - xshape, index->dims()[0], 0); + xshape, index->dims()[0], axis); } else { r = xpu::gather( dev_ctx.x_context(), reinterpret_cast(x->data()), index->data(), reinterpret_cast(output->data()), xshape, - index->dims()[0], 0); + index->dims()[0], axis); } PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, platform::errors::External( @@ -102,9 +113,19 @@ class GatherGradOpXPUKernel : public framework::OpKernel { auto *dout = ctx.Input(framework::GradVarName("Out")); auto &dev_ctx = ctx.template device_context(); + int axis = ctx.Attr("axis"); if (ctx.HasInput("Axis")) { - PADDLE_THROW(platform::errors::InvalidArgument( - "Now, it doesn't support XPU with Axis.")); + Tensor cpu_axis; + const Tensor *axis_tensor = ctx.Input("Axis"); + framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis); + const auto &axis_type = axis_tensor->dtype(); + if (framework::TransToProtoVarType(axis_type) == + framework::proto::VarType::INT32) { + axis = static_cast(cpu_axis.data()[0]); + } else if (framework::TransToProtoVarType(axis_type) == + framework::proto::VarType::INT64) { + axis = static_cast(cpu_axis.data()[0]); + } } if (dout->numel() == 0) { return; @@ -139,7 +160,7 @@ class GatherGradOpXPUKernel : public framework::OpKernel { dev_ctx.x_context(), reinterpret_cast(dout->data()), index->data(), reinterpret_cast(dx->data()), - xshape, index->dims()[0], 0, overwrite); + xshape, index->dims()[0], axis, overwrite); } else { xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); int *index_int_ptr_l3 = @@ -156,7 +177,7 @@ class GatherGradOpXPUKernel : public framework::OpKernel { dev_ctx.x_context(), reinterpret_cast(dout->data()), index_int_ptr_l3, reinterpret_cast(dx->data()), xshape, index->dims()[0], - 0, overwrite); + axis, overwrite); } PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, platform::errors::External( diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index 99f8e5ace9c..b94d0353e5d 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -38,6 +38,11 @@ XPUOpMap& get_kl2_ops() { {"argsort", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, + {"assign", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace())})}, {"assign_value", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"batch_norm_grad", @@ -209,6 +214,8 @@ XPUOpMap& get_kl2_ops() { pOpKernelType(vartype::FP16, XPUPlace())})}, {"gelu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, + {"generate_proposals_v2", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"greater_equal", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()), diff --git a/python/paddle/fluid/tests/unittests/xpu/test_generate_proposals_v2_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_generate_proposals_v2_op_xpu.py new file mode 100644 index 00000000000..764b4e81cce --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_generate_proposals_v2_op_xpu.py @@ -0,0 +1,544 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import sys +sys.path.append("..") + +import math +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core + +from op_test import OpTest +import copy +from op_test_xpu import XPUOpTest +from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper + +paddle.enable_static() + + +def box_coder(all_anchors, bbox_deltas, variances, pixel_offset=True): + """ + Decode proposals by anchors and bbox_deltas from RPN + """ + offset = 1 if pixel_offset else 0 + # proposals: xmin, ymin, xmax, ymax + proposals = np.zeros_like(bbox_deltas, dtype=np.float32) + + # anchor_loc: width, height, center_x, center_y + anchor_loc = np.zeros_like(bbox_deltas, dtype=np.float32) + + anchor_loc[:, 0] = all_anchors[:, 2] - all_anchors[:, 0] + offset + anchor_loc[:, 1] = all_anchors[:, 3] - all_anchors[:, 1] + offset + anchor_loc[:, 2] = all_anchors[:, 0] + 0.5 * anchor_loc[:, 0] + anchor_loc[:, 3] = all_anchors[:, 1] + 0.5 * anchor_loc[:, 1] + + # predicted bbox: bbox_center_x, bbox_center_y, bbox_width, bbox_height + pred_bbox = np.zeros_like(bbox_deltas, dtype=np.float32) + if variances is not None: + for i in range(bbox_deltas.shape[0]): + pred_bbox[i, 0] = variances[i, 0] * bbox_deltas[i, 0] * anchor_loc[ + i, 0] + anchor_loc[i, 2] + pred_bbox[i, 1] = variances[i, 1] * bbox_deltas[i, 1] * anchor_loc[ + i, 1] + anchor_loc[i, 3] + pred_bbox[i, 2] = math.exp( + min(variances[i, 2] * bbox_deltas[i, 2], math.log( + 1000 / 16.0))) * anchor_loc[i, 0] + pred_bbox[i, 3] = math.exp( + min(variances[i, 3] * bbox_deltas[i, 3], math.log( + 1000 / 16.0))) * anchor_loc[i, 1] + else: + for i in range(bbox_deltas.shape[0]): + pred_bbox[i, 0] = bbox_deltas[i, 0] * anchor_loc[i, 0] + anchor_loc[ + i, 2] + pred_bbox[i, 1] = bbox_deltas[i, 1] * anchor_loc[i, 1] + anchor_loc[ + i, 3] + pred_bbox[i, 2] = math.exp( + min(bbox_deltas[i, 2], math.log(1000 / 16.0))) * anchor_loc[i, + 0] + pred_bbox[i, 3] = math.exp( + min(bbox_deltas[i, 3], math.log(1000 / 16.0))) * anchor_loc[i, + 1] + proposals[:, 0] = pred_bbox[:, 0] - pred_bbox[:, 2] / 2 + proposals[:, 1] = pred_bbox[:, 1] - pred_bbox[:, 3] / 2 + proposals[:, 2] = pred_bbox[:, 0] + pred_bbox[:, 2] / 2 - offset + proposals[:, 3] = pred_bbox[:, 1] + pred_bbox[:, 3] / 2 - offset + + return proposals + + +def clip_tiled_boxes(boxes, im_shape, pixel_offset=True): + """Clip boxes to image boundaries. im_shape is [height, width] and boxes + has shape (N, 4 * num_tiled_boxes).""" + assert boxes.shape[1] % 4 == 0, \ + 'boxes.shape[1] is {:d}, but must be divisible by 4.'.format( + boxes.shape[1] + ) + offset = 1 if pixel_offset else 0 + # x1 >= 0 + boxes[:, 0::4] = np.maximum( + np.minimum(boxes[:, 0::4], im_shape[1] - offset), 0) + # y1 >= 0 + boxes[:, 1::4] = np.maximum( + np.minimum(boxes[:, 1::4], im_shape[0] - offset), 0) + # x2 < im_shape[1] + boxes[:, 2::4] = np.maximum( + np.minimum(boxes[:, 2::4], im_shape[1] - offset), 0) + # y2 < im_shape[0] + boxes[:, 3::4] = np.maximum( + np.minimum(boxes[:, 3::4], im_shape[0] - offset), 0) + return boxes + + +def filter_boxes(boxes, min_size, im_shape, pixel_offset=True): + """Only keep boxes with both sides >= min_size and center within the image. + """ + # Scale min_size to match image scale + min_size = max(min_size, 1.0) + offset = 1 if pixel_offset else 0 + ws = boxes[:, 2] - boxes[:, 0] + offset + hs = boxes[:, 3] - boxes[:, 1] + offset + if pixel_offset: + x_ctr = boxes[:, 0] + ws / 2. + y_ctr = boxes[:, 1] + hs / 2. + keep = np.where((ws >= min_size) & (hs >= min_size) & (x_ctr < im_shape[ + 1]) & (y_ctr < im_shape[0]))[0] + else: + keep = np.where((ws >= min_size) & (hs >= min_size))[0] + return keep + + +def iou(box_a, box_b, pixel_offset=True): + """ + Apply intersection-over-union overlap between box_a and box_b + """ + xmin_a = min(box_a[0], box_a[2]) + ymin_a = min(box_a[1], box_a[3]) + xmax_a = max(box_a[0], box_a[2]) + ymax_a = max(box_a[1], box_a[3]) + + xmin_b = min(box_b[0], box_b[2]) + ymin_b = min(box_b[1], box_b[3]) + xmax_b = max(box_b[0], box_b[2]) + ymax_b = max(box_b[1], box_b[3]) + offset = 1 if pixel_offset else 0 + area_a = (ymax_a - ymin_a + offset) * (xmax_a - xmin_a + offset) + area_b = (ymax_b - ymin_b + offset) * (xmax_b - xmin_b + offset) + if area_a <= 0 and area_b <= 0: + return 0.0 + + xa = max(xmin_a, xmin_b) + ya = max(ymin_a, ymin_b) + xb = min(xmax_a, xmax_b) + yb = min(ymax_a, ymax_b) + + inter_area = max(xb - xa + offset, 0.0) * max(yb - ya + offset, 0.0) + + iou_ratio = inter_area / (area_a + area_b - inter_area) + + return iou_ratio + + +def nms(boxes, scores, nms_threshold, eta=1.0, pixel_offset=True): + """Apply non-maximum suppression at test time to avoid detecting too many + overlapping bounding boxes for a given object. + Args: + boxes: (tensor) The location preds for the img, Shape: [num_priors,4]. + scores: (tensor) The class predscores for the img, Shape:[num_priors]. + nms_threshold: (float) The overlap thresh for suppressing unnecessary + boxes. + eta: (float) The parameter for adaptive NMS. + Return: + The indices of the kept boxes with respect to num_priors. + """ + all_scores = copy.deepcopy(scores) + all_scores = all_scores.flatten() + + sorted_indices = np.argsort(-all_scores, axis=0, kind='mergesort') + sorted_scores = all_scores[sorted_indices] + selected_indices = [] + adaptive_threshold = nms_threshold + for i in range(sorted_scores.shape[0]): + idx = sorted_indices[i] + keep = True + for k in range(len(selected_indices)): + if keep: + kept_idx = selected_indices[k] + overlap = iou(boxes[idx], + boxes[kept_idx], + pixel_offset=pixel_offset) + keep = True if overlap <= adaptive_threshold else False + else: + break + if keep: + selected_indices.append(idx) + if keep and eta < 1 and adaptive_threshold > 0.5: + adaptive_threshold *= eta + return selected_indices + + +def proposal_for_one_image(im_shape, all_anchors, variances, bbox_deltas, + scores, pre_nms_topN, post_nms_topN, nms_thresh, + min_size, eta, pixel_offset): + # Transpose and reshape predicted bbox transformations to get them + # into the same order as the anchors: + # - bbox deltas will be (4 * A, H, W) format from conv output + # - transpose to (H, W, 4 * A) + # - reshape to (H * W * A, 4) where rows are ordered by (H, W, A) + # in slowest to fastest order to match the enumerated anchors + all_anchors = copy.deepcopy(all_anchors) + variances = copy.deepcopy(variances) + bbox_deltas = copy.deepcopy(bbox_deltas) + scores = copy.deepcopy(scores) + bbox_deltas = bbox_deltas.transpose((1, 2, 0)).reshape(-1, 4) + all_anchors = all_anchors.reshape(-1, 4) + variances = variances.reshape(-1, 4) + # Same story for the scores: + # - scores are (A, H, W) format from conv output + # - transpose to (H, W, A) + # - reshape to (H * W * A, 1) where rows are ordered by (H, W, A) + # to match the order of anchors and bbox_deltas + scores = scores.transpose((1, 2, 0)).reshape(-1, 1) + + # sort all (proposal, score) pairs by score from highest to lowest + # take top pre_nms_topN (e.g. 6000) + if pre_nms_topN <= 0 or pre_nms_topN >= len(scores): + order = np.argsort(-scores.squeeze()) + else: + # Avoid sorting possibly large arrays; + # First partition to get top K unsorted + # and then sort just those + inds = np.argpartition(-scores.squeeze(), pre_nms_topN)[:pre_nms_topN] + order = np.argsort(-scores[inds].squeeze()) + order = inds[order] + scores = scores[order, :] + bbox_deltas = bbox_deltas[order, :] + all_anchors = all_anchors[order, :] + variances = variances[order, :] + proposals = box_coder(all_anchors, bbox_deltas, variances, pixel_offset) + # clip proposals to image (may result in proposals with zero area + # that will be removed in the next step) + proposals = clip_tiled_boxes(proposals, im_shape, pixel_offset) + # remove predicted boxes with height or width < min_size + keep = filter_boxes(proposals, min_size, im_shape, pixel_offset) + if len(keep) == 0: + proposals = np.zeros((1, 4)).astype('float32') + scores = np.zeros((1, 1)).astype('float32') + return proposals, scores + proposals = proposals[keep, :] + scores = scores[keep, :] + + # apply loose nms (e.g. threshold = 0.7) + # take post_nms_topN (e.g. 1000) + # return the top proposals + if nms_thresh > 0: + keep = nms(boxes=proposals, + scores=scores, + nms_threshold=nms_thresh, + eta=eta, + pixel_offset=pixel_offset) + if post_nms_topN > 0 and post_nms_topN < len(keep): + keep = keep[:post_nms_topN] + proposals = proposals[keep, :] + scores = scores[keep, :] + + return proposals, scores + + +def generate_proposals_v2_in_python(scores, bbox_deltas, im_shape, anchors, + variances, pre_nms_topN, post_nms_topN, + nms_thresh, min_size, eta, pixel_offset): + all_anchors = anchors.reshape(-1, 4) + rois = np.empty((0, 5), dtype=np.float32) + roi_probs = np.empty((0, 1), dtype=np.float32) + + rpn_rois = [] + rpn_roi_probs = [] + rois_num = [] + num_images = scores.shape[0] + for img_idx in range(num_images): + img_i_boxes, img_i_probs = proposal_for_one_image( + im_shape[img_idx, :], all_anchors, variances, + bbox_deltas[img_idx, :, :, :], scores[img_idx, :, :, :], + pre_nms_topN, post_nms_topN, nms_thresh, min_size, eta, + pixel_offset) + rois_num.append(img_i_probs.shape[0]) + rpn_rois.append(img_i_boxes) + rpn_roi_probs.append(img_i_probs) + + return rpn_rois, rpn_roi_probs, rois_num + + +def anchor_generator_in_python(input_feat, anchor_sizes, aspect_ratios, + variances, stride, offset): + num_anchors = len(aspect_ratios) * len(anchor_sizes) + layer_h = input_feat.shape[2] + layer_w = input_feat.shape[3] + out_dim = (layer_h, layer_w, num_anchors, 4) + out_anchors = np.zeros(out_dim).astype('float32') + + for h_idx in range(layer_h): + for w_idx in range(layer_w): + x_ctr = (w_idx * stride[0]) + offset * (stride[0] - 1) + y_ctr = (h_idx * stride[1]) + offset * (stride[1] - 1) + idx = 0 + for r in range(len(aspect_ratios)): + ar = aspect_ratios[r] + for s in range(len(anchor_sizes)): + anchor_size = anchor_sizes[s] + area = stride[0] * stride[1] + area_ratios = area / ar + base_w = np.round(np.sqrt(area_ratios)) + base_h = np.round(base_w * ar) + scale_w = anchor_size / stride[0] + scale_h = anchor_size / stride[1] + w = scale_w * base_w + h = scale_h * base_h + out_anchors[h_idx, w_idx, idx, :] = [ + (x_ctr - 0.5 * (w - 1)), (y_ctr - 0.5 * (h - 1)), + (x_ctr + 0.5 * (w - 1)), (y_ctr + 0.5 * (h - 1)) + ] + idx += 1 + + # set the variance. + out_var = np.tile(variances, (layer_h, layer_w, num_anchors, 1)) + out_anchors = out_anchors.astype('float32') + out_var = out_var.astype('float32') + return out_anchors, out_var + + +class XPUGenerateProposalsV2Op(XPUOpTestWrapper): + def __init__(self): + self.op_name = 'generate_proposals_v2' + self.use_dynamic_create_class = False + + class TestGenerateProposalsV2Op(XPUOpTest): + def set_data(self): + self.init_input_shape() + self.init_test_params() + self.init_test_input() + self.init_test_output() + self.inputs = { + 'Scores': self.scores, + 'BboxDeltas': self.bbox_deltas, + 'ImShape': self.im_shape.astype(self.dtype), + 'Anchors': self.anchors, + 'Variances': self.variances + } + + self.attrs = { + 'pre_nms_topN': self.pre_nms_topN, + 'post_nms_topN': self.post_nms_topN, + 'nms_thresh': self.nms_thresh, + 'min_size': self.min_size, + 'eta': self.eta, + 'pixel_offset': self.pixel_offset, + } + + self.outputs = { + 'RpnRois': (self.rpn_rois[0], [self.rois_num]), + 'RpnRoiProbs': (self.rpn_roi_probs[0], [self.rois_num]), + } + + def test_check_output(self): + if paddle.is_compiled_with_xpu(): + self.check_output_with_place(self.place) + + def setUp(self): + self.set_xpu() + self.op_type = "generate_proposals_v2" + self.place = paddle.XPUPlace(0) + self.init_dtype() + self.set_data() + + def set_xpu(self): + self.__class__.use_xpu = True + self.__class__.no_need_check_grad = True + + def init_input_shape(self): + self.input_feat_shape = (1, 20, 16, 16) + self.im_shape = np.array([[64, 64]]).astype(self.dtype) + + def init_dtype(self): + self.dtype = self.in_type + + def init_test_params(self): + self.pre_nms_topN = 12000 # train 12000, test 2000 + self.post_nms_topN = 5000 # train 6000, test 1000 + self.nms_thresh = 0.7 + self.min_size = 3.0 + self.eta = 1. + self.pixel_offset = True + + def init_test_input(self): + batch_size = self.input_feat_shape[0] + input_channels = self.input_feat_shape[1] + layer_h = self.input_feat_shape[2] + layer_w = self.input_feat_shape[3] + input_feat = np.random.random((batch_size, input_channels, layer_h, + layer_w)).astype(self.dtype) + self.anchors, self.variances = anchor_generator_in_python( + input_feat=input_feat, + anchor_sizes=[16., 32.], + aspect_ratios=[0.5, 1.0], + variances=[1.0, 1.0, 1.0, 1.0], + stride=[16.0, 16.0], + offset=0.5) + num_anchors = self.anchors.shape[2] + self.scores = np.random.random( + (batch_size, num_anchors, layer_h, layer_w)).astype(self.dtype) + self.bbox_deltas = np.random.random( + (batch_size, num_anchors * 4, layer_h, + layer_w)).astype(self.dtype) + + def init_test_output(self): + self.rpn_rois, self.rpn_roi_probs, self.rois_num = generate_proposals_v2_in_python( + self.scores, self.bbox_deltas, self.im_shape, self.anchors, + self.variances, self.pre_nms_topN, self.post_nms_topN, + self.nms_thresh, self.min_size, self.eta, self.pixel_offset) + + class TestGenerateProposalsV2OutLodOp(TestGenerateProposalsV2Op): + def set_data(self): + self.init_input_shape() + self.init_test_params() + self.init_test_input() + self.init_test_output() + self.inputs = { + 'Scores': self.scores, + 'BboxDeltas': self.bbox_deltas, + 'ImShape': self.im_shape.astype(np.float32), + 'Anchors': self.anchors, + 'Variances': self.variances + } + + self.attrs = { + 'pre_nms_topN': self.pre_nms_topN, + 'post_nms_topN': self.post_nms_topN, + 'nms_thresh': self.nms_thresh, + 'min_size': self.min_size, + 'eta': self.eta, + 'pixel_offset': self.pixel_offset, + 'return_rois_num': True + } + + self.outputs = { + 'RpnRois': (self.rpn_rois[0], [self.rois_num]), + 'RpnRoiProbs': (self.rpn_roi_probs[0], [self.rois_num]), + 'RpnRoisNum': (np.asarray( + self.rois_num, dtype=np.int32)) + } + + class TestGenerateProposalsV2OpNoBoxLeft(TestGenerateProposalsV2Op): + def init_test_params(self): + self.pre_nms_topN = 12000 # train 12000, test 2000 + self.post_nms_topN = 5000 # train 6000, test 1000 + self.nms_thresh = 0.7 + self.min_size = 1000.0 + self.eta = 1. + self.pixel_offset = True + + class TestGenerateProposalsV2OpNoOffset(TestGenerateProposalsV2Op): + def init_test_params(self): + self.pre_nms_topN = 12000 # train 12000, test 2000 + self.post_nms_topN = 5000 # train 6000, test 1000 + self.nms_thresh = 0.7 + self.min_size = 3.0 + self.eta = 1. + self.pixel_offset = False + + # """ + class TestGenerateProposalsV2OpMaskRcnn1XPU(TestGenerateProposalsV2Op): + def init_input_shape(self): + self.input_feat_shape = (1, 20, 48, 64) + self.im_shape = np.array([[768, 1024]]).astype(self.dtype) + + def init_test_params(self): + self.pre_nms_topN = 12000 # train 12000, test 2000 + self.post_nms_topN = 2000 # train 6000, test 1000 + self.nms_thresh = 0.7 + self.min_size = 0.0 + self.eta = 1. + self.pixel_offset = False + + def init_test_input(self): + batch_size = self.input_feat_shape[0] + input_channels = self.input_feat_shape[1] + layer_h = self.input_feat_shape[2] + layer_w = self.input_feat_shape[3] + input_feat = np.random.random((batch_size, input_channels, layer_h, + layer_w)).astype(self.dtype) + self.anchors, self.variances = anchor_generator_in_python( + input_feat=input_feat, + anchor_sizes=[32, 64, 128, 256, 512], + aspect_ratios=[0.5, 1.0, 2.0], + variances=[1.0, 1.0, 1.0, 1.0], + stride=[16.0, 16.0], + offset=0.5) + num_anchors = self.anchors.shape[2] + self.scores = np.random.random( + (batch_size, num_anchors, layer_h, layer_w)).astype(self.dtype) + self.bbox_deltas = np.random.random( + (batch_size, num_anchors * 4, layer_h, + layer_w)).astype(self.dtype) + self.anchors = self.anchors.reshape(-1, 4) + self.variances = self.variances.reshape(-1, 4) + + def set_data(self): + np.random.seed(1) + self.init_input_shape() + self.init_test_params() + self.init_test_input() + self.init_test_output() + + self.inputs = { + 'Scores': self.scores, + 'BboxDeltas': self.bbox_deltas, + 'ImShape': self.im_shape.astype(np.float32), + 'Anchors': self.anchors, + 'Variances': self.variances + } + + self.attrs = { + 'pre_nms_topN': self.pre_nms_topN, + 'post_nms_topN': self.post_nms_topN, + 'nms_thresh': self.nms_thresh, + 'min_size': self.min_size, + 'eta': self.eta, + 'pixel_offset': self.pixel_offset, + 'return_rois_num': True + } + + self.outputs = { + 'RpnRois': (self.rpn_rois[0], [self.rois_num]), + 'RpnRoiProbs': (self.rpn_roi_probs[0], [self.rois_num]), + 'RpnRoisNum': (np.asarray( + self.rois_num, dtype=np.int32)) + } + + +support_types = get_xpu_op_support_types('generate_proposals_v2') +for stype in support_types: + create_test_class( + globals(), + XPUGenerateProposalsV2Op, + stype, + test_grad=False, + ignore_deivce_version=[core.XPUVersion.XPU1]) + +if __name__ == '__main__': + unittest.main() -- GitLab