未验证 提交 ff22a9c4 编写于 作者: L Leo Guo 提交者: GitHub

Add generate_proposals_v2 op and expend function of gather op for kunlun. *test=kunlun (#43162)

* Add generate_proposals_v2 op and unittest for kunlun. *test=kunlun

* Add the assign op to xpu2_op_list and expand the function of gather op. Add the unit-test of generate_proposals_v2. *test=kunlun
上级 4d3b7d7d
......@@ -39,12 +39,14 @@ endif()
if(WITH_XPU)
detection_library(iou_similarity_op SRCS iou_similarity_op.cc iou_similarity_op_xpu.cc)
detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op_xpu.cc)
detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc generate_proposals_v2_op_xpu.cc)
elseif(WITH_ASCEND_CL)
detection_library(iou_similarity_op SRCS iou_similarity_op.cc iou_similarity_op_npu.cc)
detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu prior_box_op_npu.cc)
else()
detection_library(iou_similarity_op SRCS iou_similarity_op.cc iou_similarity_op.cu)
detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu)
# detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc)
endif()
detection_library(bipartite_match_op SRCS bipartite_match_op.cc)
......@@ -81,7 +83,9 @@ if(WITH_GPU OR WITH_ROCM)
detection_library(collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc collect_fpn_proposals_op.cu DEPS ${TMPDEPS})
else()
detection_library(generate_proposals_op SRCS generate_proposals_op.cc)
if(NOT WITH_XPU)
detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc)
endif()
detection_library(distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc)
detection_library(collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc)
endif()
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include <paddle/fluid/memory/allocation/allocator.h>
#include <stdio.h>
#include <string>
#include <vector>
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
namespace {
template <typename T>
static void SortDescending(const platform::XPUDeviceContext &dev_ctx,
const Tensor &value, Tensor *index_out,
int pre_nms_top_n) {
auto *value_data = value.data<T>();
auto place = dev_ctx.GetPlace();
auto cpu_place = platform::CPUPlace();
Tensor scores_slice_cpu;
scores_slice_cpu.Resize({value.numel()});
auto *scores_slice_cpu_data = scores_slice_cpu.mutable_data<T>(cpu_place);
memory::Copy(cpu_place, scores_slice_cpu_data, place, value_data,
sizeof(T) * value.numel());
// Sort index
Tensor index_t;
int *index = index_t.mutable_data<int>({value.numel()}, cpu_place);
for (int i = 0; i < value.numel(); ++i) {
index[i] = i;
}
auto compare = [scores_slice_cpu_data](const int64_t &i, const int64_t &j) {
return scores_slice_cpu_data[i] > scores_slice_cpu_data[j];
};
if (pre_nms_top_n <= 0 || pre_nms_top_n >= value.numel()) {
std::sort(index, index + value.numel(), compare);
} else {
std::nth_element(index, index + pre_nms_top_n, index + value.numel(),
compare);
std::sort(index, index + pre_nms_top_n, compare);
index_t.Resize({pre_nms_top_n});
}
int *idx_out =
index_out->mutable_data<int>({index_t.numel()}, dev_ctx.GetPlace());
memory::Copy(place, idx_out, cpu_place, index, sizeof(T) * index_t.numel());
}
template <typename T>
static std::pair<Tensor, Tensor> ProposalForOneImage(
const platform::XPUDeviceContext &dev_ctx, const Tensor &im_shape,
const Tensor &anchors, const Tensor &variances,
const Tensor &bbox_deltas, // [M, 4]
const Tensor &scores, // [N, 1]
int pre_nms_top_n, int post_nms_top_n, float nms_thresh, float min_size,
float eta, bool pixel_offset) {
// 1. pre nms
Tensor index_sort;
SortDescending<T>(dev_ctx, scores, &index_sort, pre_nms_top_n);
Tensor scores_sel, bbox_sel, anchor_sel, var_sel;
scores_sel.mutable_data<T>({index_sort.numel(), 1}, dev_ctx.GetPlace());
bbox_sel.mutable_data<T>({index_sort.numel(), 4}, dev_ctx.GetPlace());
anchor_sel.mutable_data<T>({index_sort.numel(), 4}, dev_ctx.GetPlace());
var_sel.mutable_data<T>({index_sort.numel(), 4}, dev_ctx.GetPlace());
int r = xpu::gather<T>(dev_ctx.x_context(), scores.data<T>(),
index_sort.data<int>(), scores_sel.data<T>(),
{static_cast<int>(scores.numel()), 1},
index_sort.numel(), 0);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External("XPU API(gather) return "
"wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
r = xpu::gather<T>(dev_ctx.x_context(), bbox_deltas.data<T>(),
index_sort.data<int>(), bbox_sel.data<T>(),
{static_cast<int>(bbox_deltas.numel()) / 4, 4},
index_sort.numel(), 0);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External("XPU API(gather) return "
"wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
r = xpu::gather<T>(dev_ctx.x_context(), anchors.data<T>(),
index_sort.data<int>(), anchor_sel.data<T>(),
{static_cast<int>(anchors.numel()) / 4, 4},
index_sort.numel(), 0);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External("XPU API(gather) return "
"wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
r = xpu::gather<T>(dev_ctx.x_context(), variances.data<T>(),
index_sort.data<int>(), var_sel.data<T>(),
{static_cast<int>(variances.numel()) / 4, 4},
index_sort.numel(), 0);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External("XPU API(gather) return "
"wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
int num = scores.numel();
int pre_nms_num = (pre_nms_top_n <= 0 || pre_nms_top_n > num) ? scores.numel()
: pre_nms_top_n;
scores_sel.Resize({pre_nms_num, 1});
index_sort.Resize({pre_nms_num, 1});
// 2. box decode and clipping
Tensor proposals;
proposals.mutable_data<T>({pre_nms_num, 4}, dev_ctx.GetPlace());
r = xpu::box_decoder<T>(dev_ctx.x_context(), anchor_sel.data<T>(),
var_sel.data<T>(), bbox_sel.data<T>(),
proposals.data<T>(), pre_nms_num, !pixel_offset, true,
im_shape.data<T>());
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External("XPU API(box_decoder) return "
"wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
// 3. filter
Tensor keep_index, keep_num_t;
keep_index.mutable_data<int>({pre_nms_num}, dev_ctx.GetPlace());
keep_num_t.mutable_data<int>({1}, dev_ctx.GetPlace());
min_size = std::max(min_size, 1.0f);
r = xpu::remove_small_boxes<T>(dev_ctx.x_context(), proposals.data<T>(),
im_shape.data<T>(), keep_index.data<int>(),
keep_num_t.data<int>(), pre_nms_num, min_size,
false, pixel_offset);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
"XPU API(remove_small_boxes) return "
"wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
int keep_num;
const auto xpu_place = dev_ctx.GetPlace();
memory::Copy(platform::CPUPlace(), &keep_num, xpu_place,
keep_num_t.data<int>(), sizeof(int));
keep_index.Resize({keep_num});
Tensor scores_filter, proposals_filter;
// Handle the case when there is no keep index left
if (keep_num == 0) {
phi::funcs::SetConstant<platform::XPUDeviceContext, T> set_zero;
proposals_filter.mutable_data<T>({1, 4}, dev_ctx.GetPlace());
scores_filter.mutable_data<T>({1, 1}, dev_ctx.GetPlace());
set_zero(dev_ctx, &proposals_filter, static_cast<T>(0));
set_zero(dev_ctx, &scores_filter, static_cast<T>(0));
return std::make_pair(proposals_filter, scores_filter);
}
proposals_filter.mutable_data<T>({keep_num, 4}, dev_ctx.GetPlace());
scores_filter.mutable_data<T>({keep_num, 1}, dev_ctx.GetPlace());
r = xpu::gather<T>(dev_ctx.x_context(), proposals.data<T>(),
keep_index.data<int>(), proposals_filter.data<T>(),
{pre_nms_num, 4}, keep_num, 0);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External("XPU API(gather) return "
"wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
r = xpu::gather<T>(dev_ctx.x_context(), scores_sel.data<T>(),
keep_index.data<int>(), scores_filter.data<T>(),
{pre_nms_num, 1}, keep_num, 0);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External("XPU API(gather) return "
"wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
if (nms_thresh <= 0) {
if (dev_ctx.x_context()->xpu_stream) {
dev_ctx.Wait();
}
return std::make_pair(proposals_filter, scores_filter);
}
// 4. nms
int nms_keep_num = 0;
r = xpu::nms<T>(dev_ctx.x_context(), proposals_filter.data<T>(), nullptr,
keep_index.data<int>(), 1, 1, keep_num, -1, nms_thresh, -1, 0,
&nms_keep_num, pixel_offset);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External("XPU API(nms) return the"
"wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
if (post_nms_top_n > 0 && post_nms_top_n < nms_keep_num) {
keep_index.Resize({post_nms_top_n});
} else {
keep_index.Resize({nms_keep_num});
}
Tensor scores_nms, proposals_nms;
proposals_nms.mutable_data<T>({keep_index.numel(), 4}, dev_ctx.GetPlace());
scores_nms.mutable_data<T>({keep_index.numel(), 1}, dev_ctx.GetPlace());
r = xpu::gather<T>(dev_ctx.x_context(), proposals_filter.data<T>(),
keep_index.data<int>(), proposals_nms.data<T>(),
{keep_num, 4}, keep_index.numel(), 0);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External("XPU API(gather) return "
"wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
r = xpu::gather<T>(dev_ctx.x_context(), scores_filter.data<T>(),
keep_index.data<int>(), scores_nms.data<T>(),
{keep_num, 1}, keep_index.numel(), 0);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External("XPU API(gather) return "
"wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
if (dev_ctx.x_context()->xpu_stream) {
dev_ctx.Wait();
}
return std::make_pair(proposals_nms, scores_nms);
}
} // namespace
template <typename DeviceContext, typename T>
class XPUGenerateProposalsV2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *scores = context.Input<Tensor>("Scores");
auto *bbox_deltas = context.Input<Tensor>("BboxDeltas");
auto *im_shape = context.Input<Tensor>("ImShape");
auto anchors = GET_DATA_SAFELY(context.Input<Tensor>("Anchors"), "Input",
"Anchors", "GenerateProposals");
auto variances = GET_DATA_SAFELY(context.Input<Tensor>("Variances"),
"Input", "Variances", "GenerateProposals");
auto *rpn_rois = context.Output<LoDTensor>("RpnRois");
auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs");
int pre_nms_top_n = context.Attr<int>("pre_nms_topN");
int post_nms_top_n = context.Attr<int>("post_nms_topN");
float nms_thresh = context.Attr<float>("nms_thresh");
float min_size = context.Attr<float>("min_size");
float eta = context.Attr<float>("eta");
bool pixel_offset = context.Attr<bool>("pixel_offset");
PADDLE_ENFORCE_GE(eta, 1.,
platform::errors::InvalidArgument(
"Not support adaptive NMS. The attribute 'eta' "
"should not less than 1. But received eta=[%d]",
eta));
auto &dev_ctx = context.template device_context<DeviceContext>();
auto scores_dim = scores->dims();
// the shape of bbox score
int num = scores_dim[0];
int c_score = scores_dim[1];
int h_score = scores_dim[2];
int w_score = scores_dim[3];
auto bbox_dim = bbox_deltas->dims();
int c_bbox = bbox_dim[1];
int h_bbox = bbox_dim[2];
int w_bbox = bbox_dim[3];
Tensor bbox_deltas_swap, scores_swap;
bbox_deltas_swap.mutable_data<T>({num, h_bbox, w_bbox, c_bbox},
dev_ctx.GetPlace());
scores_swap.mutable_data<T>({num, h_score, w_score, c_score},
dev_ctx.GetPlace());
std::vector<int> axis = {0, 2, 3, 1};
int r = xpu::transpose<T>(dev_ctx.x_context(), bbox_deltas->data<T>(),
bbox_deltas_swap.data<T>(),
{num, c_bbox, h_bbox, w_bbox}, axis);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External("XPU API(transpose) return "
"wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
r = xpu::transpose<T>(dev_ctx.x_context(), scores->data<T>(),
scores_swap.data<T>(),
{num, c_score, h_score, w_score}, axis);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External("XPU API(transpose) return "
"wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
anchors.Resize({anchors.numel() / 4, 4});
variances.Resize({variances.numel() / 4, 4});
// output
rpn_rois->mutable_data<T>({bbox_deltas->numel() / 4, 4},
context.GetPlace());
rpn_roi_probs->mutable_data<T>({scores->numel(), 1}, context.GetPlace());
T *rpn_rois_data = rpn_rois->data<T>();
T *rpn_roi_probs_data = rpn_roi_probs->data<T>();
auto place = dev_ctx.GetPlace();
auto cpu_place = platform::CPUPlace();
int num_proposals = 0;
std::vector<size_t> offset(1, 0);
std::vector<int> tmp_num;
for (int64_t i = 0; i < num; ++i) {
Tensor im_shape_slice = im_shape->Slice(i, i + 1);
Tensor bbox_deltas_slice = bbox_deltas_swap.Slice(i, i + 1);
Tensor scores_slice = scores_swap.Slice(i, i + 1);
bbox_deltas_slice.Resize({h_bbox * w_bbox * c_bbox / 4, 4});
scores_slice.Resize({h_score * w_score * c_score, 1});
std::pair<Tensor, Tensor> box_score_pair = ProposalForOneImage<T>(
dev_ctx, im_shape_slice, anchors, variances, bbox_deltas_slice,
scores_slice, pre_nms_top_n, post_nms_top_n, nms_thresh, min_size,
eta, pixel_offset);
Tensor &proposals = box_score_pair.first;
Tensor &scores = box_score_pair.second;
memory::Copy(place, rpn_rois_data + num_proposals * 4, place,
proposals.data<T>(), sizeof(T) * proposals.numel());
memory::Copy(place, rpn_roi_probs_data + num_proposals, place,
scores.data<T>(), sizeof(T) * scores.numel());
if (dev_ctx.x_context()->xpu_stream) {
dev_ctx.Wait();
}
num_proposals += proposals.dims()[0];
offset.emplace_back(num_proposals);
tmp_num.push_back(proposals.dims()[0]);
}
if (context.HasOutput("RpnRoisNum")) {
auto *rpn_rois_num = context.Output<Tensor>("RpnRoisNum");
rpn_rois_num->mutable_data<int>({num}, context.GetPlace());
int *num_data = rpn_rois_num->data<int>();
memory::Copy(place, num_data, cpu_place, &tmp_num[0], sizeof(int) * num);
rpn_rois_num->Resize({num});
}
framework::LoD lod;
lod.emplace_back(offset);
rpn_rois->set_lod(lod);
rpn_roi_probs->set_lod(lod);
rpn_rois->Resize({num_proposals, 4});
rpn_roi_probs->Resize({num_proposals, 1});
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(generate_proposals_v2,
ops::XPUGenerateProposalsV2Kernel<
paddle::platform::XPUDeviceContext, float>);
#endif // PADDLE_WITH_XPU
......@@ -38,9 +38,20 @@ class GatherOpXPUKernel : public framework::OpKernel<T> {
auto *x = ctx.Input<Tensor>("X");
auto *index = ctx.Input<Tensor>("Index");
auto *output = ctx.Output<Tensor>("Out");
int axis = ctx.Attr<int>("axis");
if (ctx.HasInput("Axis")) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Now, it doesn't support XPU with Axis."));
Tensor cpu_axis;
const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis);
const auto &axis_type = axis_tensor->dtype();
if (framework::TransToProtoVarType(axis_type) ==
framework::proto::VarType::INT32) {
axis = static_cast<int>(cpu_axis.data<int32_t>()[0]);
} else if (framework::TransToProtoVarType(axis_type) ==
framework::proto::VarType::INT64) {
axis = static_cast<int>(cpu_axis.data<int64_t>()[0]);
}
}
output->mutable_data<T>(ctx.GetPlace());
......@@ -72,13 +83,13 @@ class GatherOpXPUKernel : public framework::OpKernel<T> {
r = xpu::gather<XPUType, int>(
dev_ctx.x_context(), reinterpret_cast<const XPUType *>(x->data<T>()),
index->data<int>(), reinterpret_cast<XPUType *>(output->data<T>()),
xshape, index->dims()[0], 0);
xshape, index->dims()[0], axis);
} else {
r = xpu::gather<XPUType, int64_t>(
dev_ctx.x_context(), reinterpret_cast<const XPUType *>(x->data<T>()),
index->data<int64_t>(),
reinterpret_cast<XPUType *>(output->data<T>()), xshape,
index->dims()[0], 0);
index->dims()[0], axis);
}
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External(
......@@ -102,9 +113,19 @@ class GatherGradOpXPUKernel : public framework::OpKernel<T> {
auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
int axis = ctx.Attr<int>("axis");
if (ctx.HasInput("Axis")) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Now, it doesn't support XPU with Axis."));
Tensor cpu_axis;
const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis);
const auto &axis_type = axis_tensor->dtype();
if (framework::TransToProtoVarType(axis_type) ==
framework::proto::VarType::INT32) {
axis = static_cast<int>(cpu_axis.data<int32_t>()[0]);
} else if (framework::TransToProtoVarType(axis_type) ==
framework::proto::VarType::INT64) {
axis = static_cast<int>(cpu_axis.data<int64_t>()[0]);
}
}
if (dout->numel() == 0) {
return;
......@@ -139,7 +160,7 @@ class GatherGradOpXPUKernel : public framework::OpKernel<T> {
dev_ctx.x_context(),
reinterpret_cast<const XPUType *>(dout->data<T>()),
index->data<int>(), reinterpret_cast<XPUType *>(dx->data<T>()),
xshape, index->dims()[0], 0, overwrite);
xshape, index->dims()[0], axis, overwrite);
} else {
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
int *index_int_ptr_l3 =
......@@ -156,7 +177,7 @@ class GatherGradOpXPUKernel : public framework::OpKernel<T> {
dev_ctx.x_context(),
reinterpret_cast<const XPUType *>(dout->data<T>()), index_int_ptr_l3,
reinterpret_cast<XPUType *>(dx->data<T>()), xshape, index->dims()[0],
0, overwrite);
axis, overwrite);
}
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External(
......
......@@ -38,6 +38,11 @@ XPUOpMap& get_kl2_ops() {
{"argsort", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"assign", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace())})},
{"assign_value",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"batch_norm_grad",
......@@ -209,6 +214,8 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::FP16, XPUPlace())})},
{"gelu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"generate_proposals_v2",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"greater_equal",
XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append("..")
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from op_test import OpTest
import copy
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
paddle.enable_static()
def box_coder(all_anchors, bbox_deltas, variances, pixel_offset=True):
"""
Decode proposals by anchors and bbox_deltas from RPN
"""
offset = 1 if pixel_offset else 0
# proposals: xmin, ymin, xmax, ymax
proposals = np.zeros_like(bbox_deltas, dtype=np.float32)
# anchor_loc: width, height, center_x, center_y
anchor_loc = np.zeros_like(bbox_deltas, dtype=np.float32)
anchor_loc[:, 0] = all_anchors[:, 2] - all_anchors[:, 0] + offset
anchor_loc[:, 1] = all_anchors[:, 3] - all_anchors[:, 1] + offset
anchor_loc[:, 2] = all_anchors[:, 0] + 0.5 * anchor_loc[:, 0]
anchor_loc[:, 3] = all_anchors[:, 1] + 0.5 * anchor_loc[:, 1]
# predicted bbox: bbox_center_x, bbox_center_y, bbox_width, bbox_height
pred_bbox = np.zeros_like(bbox_deltas, dtype=np.float32)
if variances is not None:
for i in range(bbox_deltas.shape[0]):
pred_bbox[i, 0] = variances[i, 0] * bbox_deltas[i, 0] * anchor_loc[
i, 0] + anchor_loc[i, 2]
pred_bbox[i, 1] = variances[i, 1] * bbox_deltas[i, 1] * anchor_loc[
i, 1] + anchor_loc[i, 3]
pred_bbox[i, 2] = math.exp(
min(variances[i, 2] * bbox_deltas[i, 2], math.log(
1000 / 16.0))) * anchor_loc[i, 0]
pred_bbox[i, 3] = math.exp(
min(variances[i, 3] * bbox_deltas[i, 3], math.log(
1000 / 16.0))) * anchor_loc[i, 1]
else:
for i in range(bbox_deltas.shape[0]):
pred_bbox[i, 0] = bbox_deltas[i, 0] * anchor_loc[i, 0] + anchor_loc[
i, 2]
pred_bbox[i, 1] = bbox_deltas[i, 1] * anchor_loc[i, 1] + anchor_loc[
i, 3]
pred_bbox[i, 2] = math.exp(
min(bbox_deltas[i, 2], math.log(1000 / 16.0))) * anchor_loc[i,
0]
pred_bbox[i, 3] = math.exp(
min(bbox_deltas[i, 3], math.log(1000 / 16.0))) * anchor_loc[i,
1]
proposals[:, 0] = pred_bbox[:, 0] - pred_bbox[:, 2] / 2
proposals[:, 1] = pred_bbox[:, 1] - pred_bbox[:, 3] / 2
proposals[:, 2] = pred_bbox[:, 0] + pred_bbox[:, 2] / 2 - offset
proposals[:, 3] = pred_bbox[:, 1] + pred_bbox[:, 3] / 2 - offset
return proposals
def clip_tiled_boxes(boxes, im_shape, pixel_offset=True):
"""Clip boxes to image boundaries. im_shape is [height, width] and boxes
has shape (N, 4 * num_tiled_boxes)."""
assert boxes.shape[1] % 4 == 0, \
'boxes.shape[1] is {:d}, but must be divisible by 4.'.format(
boxes.shape[1]
)
offset = 1 if pixel_offset else 0
# x1 >= 0
boxes[:, 0::4] = np.maximum(
np.minimum(boxes[:, 0::4], im_shape[1] - offset), 0)
# y1 >= 0
boxes[:, 1::4] = np.maximum(
np.minimum(boxes[:, 1::4], im_shape[0] - offset), 0)
# x2 < im_shape[1]
boxes[:, 2::4] = np.maximum(
np.minimum(boxes[:, 2::4], im_shape[1] - offset), 0)
# y2 < im_shape[0]
boxes[:, 3::4] = np.maximum(
np.minimum(boxes[:, 3::4], im_shape[0] - offset), 0)
return boxes
def filter_boxes(boxes, min_size, im_shape, pixel_offset=True):
"""Only keep boxes with both sides >= min_size and center within the image.
"""
# Scale min_size to match image scale
min_size = max(min_size, 1.0)
offset = 1 if pixel_offset else 0
ws = boxes[:, 2] - boxes[:, 0] + offset
hs = boxes[:, 3] - boxes[:, 1] + offset
if pixel_offset:
x_ctr = boxes[:, 0] + ws / 2.
y_ctr = boxes[:, 1] + hs / 2.
keep = np.where((ws >= min_size) & (hs >= min_size) & (x_ctr < im_shape[
1]) & (y_ctr < im_shape[0]))[0]
else:
keep = np.where((ws >= min_size) & (hs >= min_size))[0]
return keep
def iou(box_a, box_b, pixel_offset=True):
"""
Apply intersection-over-union overlap between box_a and box_b
"""
xmin_a = min(box_a[0], box_a[2])
ymin_a = min(box_a[1], box_a[3])
xmax_a = max(box_a[0], box_a[2])
ymax_a = max(box_a[1], box_a[3])
xmin_b = min(box_b[0], box_b[2])
ymin_b = min(box_b[1], box_b[3])
xmax_b = max(box_b[0], box_b[2])
ymax_b = max(box_b[1], box_b[3])
offset = 1 if pixel_offset else 0
area_a = (ymax_a - ymin_a + offset) * (xmax_a - xmin_a + offset)
area_b = (ymax_b - ymin_b + offset) * (xmax_b - xmin_b + offset)
if area_a <= 0 and area_b <= 0:
return 0.0
xa = max(xmin_a, xmin_b)
ya = max(ymin_a, ymin_b)
xb = min(xmax_a, xmax_b)
yb = min(ymax_a, ymax_b)
inter_area = max(xb - xa + offset, 0.0) * max(yb - ya + offset, 0.0)
iou_ratio = inter_area / (area_a + area_b - inter_area)
return iou_ratio
def nms(boxes, scores, nms_threshold, eta=1.0, pixel_offset=True):
"""Apply non-maximum suppression at test time to avoid detecting too many
overlapping bounding boxes for a given object.
Args:
boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
scores: (tensor) The class predscores for the img, Shape:[num_priors].
nms_threshold: (float) The overlap thresh for suppressing unnecessary
boxes.
eta: (float) The parameter for adaptive NMS.
Return:
The indices of the kept boxes with respect to num_priors.
"""
all_scores = copy.deepcopy(scores)
all_scores = all_scores.flatten()
sorted_indices = np.argsort(-all_scores, axis=0, kind='mergesort')
sorted_scores = all_scores[sorted_indices]
selected_indices = []
adaptive_threshold = nms_threshold
for i in range(sorted_scores.shape[0]):
idx = sorted_indices[i]
keep = True
for k in range(len(selected_indices)):
if keep:
kept_idx = selected_indices[k]
overlap = iou(boxes[idx],
boxes[kept_idx],
pixel_offset=pixel_offset)
keep = True if overlap <= adaptive_threshold else False
else:
break
if keep:
selected_indices.append(idx)
if keep and eta < 1 and adaptive_threshold > 0.5:
adaptive_threshold *= eta
return selected_indices
def proposal_for_one_image(im_shape, all_anchors, variances, bbox_deltas,
scores, pre_nms_topN, post_nms_topN, nms_thresh,
min_size, eta, pixel_offset):
# Transpose and reshape predicted bbox transformations to get them
# into the same order as the anchors:
# - bbox deltas will be (4 * A, H, W) format from conv output
# - transpose to (H, W, 4 * A)
# - reshape to (H * W * A, 4) where rows are ordered by (H, W, A)
# in slowest to fastest order to match the enumerated anchors
all_anchors = copy.deepcopy(all_anchors)
variances = copy.deepcopy(variances)
bbox_deltas = copy.deepcopy(bbox_deltas)
scores = copy.deepcopy(scores)
bbox_deltas = bbox_deltas.transpose((1, 2, 0)).reshape(-1, 4)
all_anchors = all_anchors.reshape(-1, 4)
variances = variances.reshape(-1, 4)
# Same story for the scores:
# - scores are (A, H, W) format from conv output
# - transpose to (H, W, A)
# - reshape to (H * W * A, 1) where rows are ordered by (H, W, A)
# to match the order of anchors and bbox_deltas
scores = scores.transpose((1, 2, 0)).reshape(-1, 1)
# sort all (proposal, score) pairs by score from highest to lowest
# take top pre_nms_topN (e.g. 6000)
if pre_nms_topN <= 0 or pre_nms_topN >= len(scores):
order = np.argsort(-scores.squeeze())
else:
# Avoid sorting possibly large arrays;
# First partition to get top K unsorted
# and then sort just those
inds = np.argpartition(-scores.squeeze(), pre_nms_topN)[:pre_nms_topN]
order = np.argsort(-scores[inds].squeeze())
order = inds[order]
scores = scores[order, :]
bbox_deltas = bbox_deltas[order, :]
all_anchors = all_anchors[order, :]
variances = variances[order, :]
proposals = box_coder(all_anchors, bbox_deltas, variances, pixel_offset)
# clip proposals to image (may result in proposals with zero area
# that will be removed in the next step)
proposals = clip_tiled_boxes(proposals, im_shape, pixel_offset)
# remove predicted boxes with height or width < min_size
keep = filter_boxes(proposals, min_size, im_shape, pixel_offset)
if len(keep) == 0:
proposals = np.zeros((1, 4)).astype('float32')
scores = np.zeros((1, 1)).astype('float32')
return proposals, scores
proposals = proposals[keep, :]
scores = scores[keep, :]
# apply loose nms (e.g. threshold = 0.7)
# take post_nms_topN (e.g. 1000)
# return the top proposals
if nms_thresh > 0:
keep = nms(boxes=proposals,
scores=scores,
nms_threshold=nms_thresh,
eta=eta,
pixel_offset=pixel_offset)
if post_nms_topN > 0 and post_nms_topN < len(keep):
keep = keep[:post_nms_topN]
proposals = proposals[keep, :]
scores = scores[keep, :]
return proposals, scores
def generate_proposals_v2_in_python(scores, bbox_deltas, im_shape, anchors,
variances, pre_nms_topN, post_nms_topN,
nms_thresh, min_size, eta, pixel_offset):
all_anchors = anchors.reshape(-1, 4)
rois = np.empty((0, 5), dtype=np.float32)
roi_probs = np.empty((0, 1), dtype=np.float32)
rpn_rois = []
rpn_roi_probs = []
rois_num = []
num_images = scores.shape[0]
for img_idx in range(num_images):
img_i_boxes, img_i_probs = proposal_for_one_image(
im_shape[img_idx, :], all_anchors, variances,
bbox_deltas[img_idx, :, :, :], scores[img_idx, :, :, :],
pre_nms_topN, post_nms_topN, nms_thresh, min_size, eta,
pixel_offset)
rois_num.append(img_i_probs.shape[0])
rpn_rois.append(img_i_boxes)
rpn_roi_probs.append(img_i_probs)
return rpn_rois, rpn_roi_probs, rois_num
def anchor_generator_in_python(input_feat, anchor_sizes, aspect_ratios,
variances, stride, offset):
num_anchors = len(aspect_ratios) * len(anchor_sizes)
layer_h = input_feat.shape[2]
layer_w = input_feat.shape[3]
out_dim = (layer_h, layer_w, num_anchors, 4)
out_anchors = np.zeros(out_dim).astype('float32')
for h_idx in range(layer_h):
for w_idx in range(layer_w):
x_ctr = (w_idx * stride[0]) + offset * (stride[0] - 1)
y_ctr = (h_idx * stride[1]) + offset * (stride[1] - 1)
idx = 0
for r in range(len(aspect_ratios)):
ar = aspect_ratios[r]
for s in range(len(anchor_sizes)):
anchor_size = anchor_sizes[s]
area = stride[0] * stride[1]
area_ratios = area / ar
base_w = np.round(np.sqrt(area_ratios))
base_h = np.round(base_w * ar)
scale_w = anchor_size / stride[0]
scale_h = anchor_size / stride[1]
w = scale_w * base_w
h = scale_h * base_h
out_anchors[h_idx, w_idx, idx, :] = [
(x_ctr - 0.5 * (w - 1)), (y_ctr - 0.5 * (h - 1)),
(x_ctr + 0.5 * (w - 1)), (y_ctr + 0.5 * (h - 1))
]
idx += 1
# set the variance.
out_var = np.tile(variances, (layer_h, layer_w, num_anchors, 1))
out_anchors = out_anchors.astype('float32')
out_var = out_var.astype('float32')
return out_anchors, out_var
class XPUGenerateProposalsV2Op(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'generate_proposals_v2'
self.use_dynamic_create_class = False
class TestGenerateProposalsV2Op(XPUOpTest):
def set_data(self):
self.init_input_shape()
self.init_test_params()
self.init_test_input()
self.init_test_output()
self.inputs = {
'Scores': self.scores,
'BboxDeltas': self.bbox_deltas,
'ImShape': self.im_shape.astype(self.dtype),
'Anchors': self.anchors,
'Variances': self.variances
}
self.attrs = {
'pre_nms_topN': self.pre_nms_topN,
'post_nms_topN': self.post_nms_topN,
'nms_thresh': self.nms_thresh,
'min_size': self.min_size,
'eta': self.eta,
'pixel_offset': self.pixel_offset,
}
self.outputs = {
'RpnRois': (self.rpn_rois[0], [self.rois_num]),
'RpnRoiProbs': (self.rpn_roi_probs[0], [self.rois_num]),
}
def test_check_output(self):
if paddle.is_compiled_with_xpu():
self.check_output_with_place(self.place)
def setUp(self):
self.set_xpu()
self.op_type = "generate_proposals_v2"
self.place = paddle.XPUPlace(0)
self.init_dtype()
self.set_data()
def set_xpu(self):
self.__class__.use_xpu = True
self.__class__.no_need_check_grad = True
def init_input_shape(self):
self.input_feat_shape = (1, 20, 16, 16)
self.im_shape = np.array([[64, 64]]).astype(self.dtype)
def init_dtype(self):
self.dtype = self.in_type
def init_test_params(self):
self.pre_nms_topN = 12000 # train 12000, test 2000
self.post_nms_topN = 5000 # train 6000, test 1000
self.nms_thresh = 0.7
self.min_size = 3.0
self.eta = 1.
self.pixel_offset = True
def init_test_input(self):
batch_size = self.input_feat_shape[0]
input_channels = self.input_feat_shape[1]
layer_h = self.input_feat_shape[2]
layer_w = self.input_feat_shape[3]
input_feat = np.random.random((batch_size, input_channels, layer_h,
layer_w)).astype(self.dtype)
self.anchors, self.variances = anchor_generator_in_python(
input_feat=input_feat,
anchor_sizes=[16., 32.],
aspect_ratios=[0.5, 1.0],
variances=[1.0, 1.0, 1.0, 1.0],
stride=[16.0, 16.0],
offset=0.5)
num_anchors = self.anchors.shape[2]
self.scores = np.random.random(
(batch_size, num_anchors, layer_h, layer_w)).astype(self.dtype)
self.bbox_deltas = np.random.random(
(batch_size, num_anchors * 4, layer_h,
layer_w)).astype(self.dtype)
def init_test_output(self):
self.rpn_rois, self.rpn_roi_probs, self.rois_num = generate_proposals_v2_in_python(
self.scores, self.bbox_deltas, self.im_shape, self.anchors,
self.variances, self.pre_nms_topN, self.post_nms_topN,
self.nms_thresh, self.min_size, self.eta, self.pixel_offset)
class TestGenerateProposalsV2OutLodOp(TestGenerateProposalsV2Op):
def set_data(self):
self.init_input_shape()
self.init_test_params()
self.init_test_input()
self.init_test_output()
self.inputs = {
'Scores': self.scores,
'BboxDeltas': self.bbox_deltas,
'ImShape': self.im_shape.astype(np.float32),
'Anchors': self.anchors,
'Variances': self.variances
}
self.attrs = {
'pre_nms_topN': self.pre_nms_topN,
'post_nms_topN': self.post_nms_topN,
'nms_thresh': self.nms_thresh,
'min_size': self.min_size,
'eta': self.eta,
'pixel_offset': self.pixel_offset,
'return_rois_num': True
}
self.outputs = {
'RpnRois': (self.rpn_rois[0], [self.rois_num]),
'RpnRoiProbs': (self.rpn_roi_probs[0], [self.rois_num]),
'RpnRoisNum': (np.asarray(
self.rois_num, dtype=np.int32))
}
class TestGenerateProposalsV2OpNoBoxLeft(TestGenerateProposalsV2Op):
def init_test_params(self):
self.pre_nms_topN = 12000 # train 12000, test 2000
self.post_nms_topN = 5000 # train 6000, test 1000
self.nms_thresh = 0.7
self.min_size = 1000.0
self.eta = 1.
self.pixel_offset = True
class TestGenerateProposalsV2OpNoOffset(TestGenerateProposalsV2Op):
def init_test_params(self):
self.pre_nms_topN = 12000 # train 12000, test 2000
self.post_nms_topN = 5000 # train 6000, test 1000
self.nms_thresh = 0.7
self.min_size = 3.0
self.eta = 1.
self.pixel_offset = False
# """
class TestGenerateProposalsV2OpMaskRcnn1XPU(TestGenerateProposalsV2Op):
def init_input_shape(self):
self.input_feat_shape = (1, 20, 48, 64)
self.im_shape = np.array([[768, 1024]]).astype(self.dtype)
def init_test_params(self):
self.pre_nms_topN = 12000 # train 12000, test 2000
self.post_nms_topN = 2000 # train 6000, test 1000
self.nms_thresh = 0.7
self.min_size = 0.0
self.eta = 1.
self.pixel_offset = False
def init_test_input(self):
batch_size = self.input_feat_shape[0]
input_channels = self.input_feat_shape[1]
layer_h = self.input_feat_shape[2]
layer_w = self.input_feat_shape[3]
input_feat = np.random.random((batch_size, input_channels, layer_h,
layer_w)).astype(self.dtype)
self.anchors, self.variances = anchor_generator_in_python(
input_feat=input_feat,
anchor_sizes=[32, 64, 128, 256, 512],
aspect_ratios=[0.5, 1.0, 2.0],
variances=[1.0, 1.0, 1.0, 1.0],
stride=[16.0, 16.0],
offset=0.5)
num_anchors = self.anchors.shape[2]
self.scores = np.random.random(
(batch_size, num_anchors, layer_h, layer_w)).astype(self.dtype)
self.bbox_deltas = np.random.random(
(batch_size, num_anchors * 4, layer_h,
layer_w)).astype(self.dtype)
self.anchors = self.anchors.reshape(-1, 4)
self.variances = self.variances.reshape(-1, 4)
def set_data(self):
np.random.seed(1)
self.init_input_shape()
self.init_test_params()
self.init_test_input()
self.init_test_output()
self.inputs = {
'Scores': self.scores,
'BboxDeltas': self.bbox_deltas,
'ImShape': self.im_shape.astype(np.float32),
'Anchors': self.anchors,
'Variances': self.variances
}
self.attrs = {
'pre_nms_topN': self.pre_nms_topN,
'post_nms_topN': self.post_nms_topN,
'nms_thresh': self.nms_thresh,
'min_size': self.min_size,
'eta': self.eta,
'pixel_offset': self.pixel_offset,
'return_rois_num': True
}
self.outputs = {
'RpnRois': (self.rpn_rois[0], [self.rois_num]),
'RpnRoiProbs': (self.rpn_roi_probs[0], [self.rois_num]),
'RpnRoisNum': (np.asarray(
self.rois_num, dtype=np.int32))
}
support_types = get_xpu_op_support_types('generate_proposals_v2')
for stype in support_types:
create_test_class(
globals(),
XPUGenerateProposalsV2Op,
stype,
test_grad=False,
ignore_deivce_version=[core.XPUVersion.XPU1])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册