未验证 提交 90a33ddd 编写于 作者: S sweetsky0901 提交者: GitHub

Merge pull request #6488 from sweetsky0901/detection_output

add Detection output op for SSD
......@@ -53,7 +53,6 @@ function(op_library TARGET)
if (${op_library_DEPS_len} GREATER 0)
set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE)
endif()
if (WITH_GPU)
nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
${op_common_deps})
......@@ -187,6 +186,36 @@ endfunction()
add_subdirectory(math)
add_subdirectory(nccl)
set(DEPS_OPS
cond_op
cross_entropy_op
recurrent_op
softmax_with_cross_entropy_op
softmax_op
sequence_softmax_op
sum_op
pool_op
maxout_op
unpool_op
pool_with_index_op
conv_op
conv_transpose_op
nccl_op
sequence_conv_op
sequence_pool_op
lod_rank_table_op
lod_tensor_to_array_op
array_to_lod_tensor_op
max_sequence_len_op
lstm_op
gru_op
adagrad_op
sgd_op
save_op
load_op
send_op
recv_op
detection_output_op)
if(WITH_GPU)
op_library(nccl_op DEPS nccl_common)
else()
......@@ -210,6 +239,7 @@ op_library(cond_op DEPS framework_proto tensor net_op)
op_library(cross_entropy_op DEPS cross_entropy)
op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
op_library(softmax_op DEPS softmax)
op_library(detection_output_op DEPS softmax)
op_library(sequence_softmax_op DEPS softmax)
op_library(sum_op DEPS selected_rows_functor)
op_library(sgd_op DEPS selected_rows_functor)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Indicesou may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/detection_output_op.h"
namespace paddle {
namespace operators {
class DetectionOutputOpMaker : public framework::OpProtoAndCheckerMaker {
public:
DetectionOutputOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Loc",
"(Tensor) The input tensor of detection_output operator."
"The input predict locations"
"The format of input tensor is kNCHW. Where K is priorbox point "
"numbers,"
"N is How many boxes are there on each point, "
"C is 4, H and W both are 1.");
AddInput("Conf",
"(Tensor) The input tensor of detection_output operator."
"The input priorbox confidence."
"The format of input tensor is kNCHW. Where K is priorbox point "
"numbers,"
"N is How many boxes are there on each point, "
"C is the number of classes, H and W both are 1.");
AddInput("PriorBox",
"(Tensor) The input tensor of detection_output operator."
"The format of input tensor is the position and variance "
"of the boxes");
AddOutput("Out",
"(Tensor) The output tensor of detection_output operator.");
AddAttr<int>("background_label_id", "(int), The background class index.");
AddAttr<int>("num_classes", "(int), The number of the classification.");
AddAttr<float>("nms_threshold",
"(float), The Non-maximum suppression threshold.");
AddAttr<float>("confidence_threshold",
"(float), The classification confidence threshold.");
AddAttr<int>("top_k", "(int), The bbox number kept of the layer’s output.");
AddAttr<int>("nms_top_k",
"(int), The bbox number kept of the NMS’s output.");
AddComment(R"DOC(
detection output for SSD(single shot multibox detector)
Apply the NMS to the output of network and compute the predict
bounding box location. The output’s shape of this layer could
be zero if there is no valid bounding box.
)DOC");
}
};
class DetectionOutputOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Loc"),
"Input(X) of DetectionOutputOp"
"should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Conf"),
"Input(X) of DetectionOutputOp"
"should not be null.");
PADDLE_ENFORCE(ctx->HasInput("PriorBox"),
"Input(X) of DetectionOutputOp"
"should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of DetectionOutputOp should not be null.");
std::vector<int64_t> output_shape({1, 7});
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(detection_output, ops::DetectionOutputOp,
ops::DetectionOutputOpMaker);
REGISTER_OP_CPU_KERNEL(
detection_output,
ops::DetectionOutputKernel<paddle::platform::CPUDeviceContext, float>,
ops::DetectionOutputKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Indicesou may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/detection_output_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
detection_output,
ops::DetectionOutputKernel<paddle::platform::CUDADeviceContext, float>,
ops::DetectionOutputKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Indicesou may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"
#include "paddle/operators/math/detection_util.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/softmax.h"
#include "paddle/operators/strided_memcpy.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
inline void transpose_fun(const framework::ExecutionContext& context,
const framework::Tensor& src,
framework::Tensor* dst) {
int input_nums = src.dims()[0];
int offset = 0;
for (int j = 0; j < input_nums; ++j) {
framework::Tensor in_p_tensor = src.Slice(j, j + 1);
std::vector<int64_t> shape_vec(
{in_p_tensor.dims()[0], in_p_tensor.dims()[1], in_p_tensor.dims()[3],
in_p_tensor.dims()[4], in_p_tensor.dims()[2]});
framework::DDim shape(framework::make_ddim(shape_vec));
framework::Tensor in_p_tensor_transpose;
in_p_tensor_transpose.mutable_data<T>(shape, context.GetPlace());
std::vector<int> shape_axis({0, 1, 3, 4, 2});
math::Transpose<DeviceContext, T, 5> trans5;
trans5(context.template device_context<DeviceContext>(), in_p_tensor,
&in_p_tensor_transpose, shape_axis);
auto dst_stride = framework::stride(dst->dims());
auto src_stride = framework::stride(in_p_tensor_transpose.dims());
StridedMemcpy<T>(context.device_context(), in_p_tensor_transpose.data<T>(),
src_stride, in_p_tensor_transpose.dims(), dst_stride,
dst->data<T>() + offset);
offset += in_p_tensor_transpose.dims()[4] * src_stride[4];
}
}
template <typename DeviceContext, typename T>
class DetectionOutputKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const framework::Tensor* in_loc = context.Input<framework::Tensor>("Loc");
const framework::Tensor* in_conf = context.Input<framework::Tensor>("Conf");
const framework::Tensor* in_priorbox =
context.Input<framework::Tensor>("PriorBox");
auto* out = context.Output<framework::Tensor>("Out");
int num_classes = context.template Attr<int>("num_classes");
int top_k = context.template Attr<int>("top_k");
int nms_top_k = context.template Attr<int>("nms_top_k");
int background_label_id = context.template Attr<int>("background_label_id");
float nms_threshold = context.template Attr<float>("nms_threshold");
float confidence_threshold =
context.template Attr<float>("confidence_threshold");
size_t batch_size = in_conf->dims()[1];
int conf_sum_size = in_conf->numel();
// for softmax
std::vector<int64_t> conf_shape_softmax_vec(
{conf_sum_size / num_classes, num_classes});
framework::DDim conf_shape_softmax(
framework::make_ddim(conf_shape_softmax_vec));
// for knchw => nhwc
std::vector<int64_t> loc_shape_vec({1, in_loc->dims()[1], in_loc->dims()[3],
in_loc->dims()[4],
in_loc->dims()[2] * in_loc->dims()[0]});
std::vector<int64_t> conf_shape_vec(
{1, in_conf->dims()[1], in_conf->dims()[3], in_conf->dims()[4],
in_conf->dims()[2] * in_conf->dims()[0]});
framework::DDim loc_shape(framework::make_ddim(loc_shape_vec));
framework::DDim conf_shape(framework::make_ddim(conf_shape_vec));
framework::Tensor loc_tensor;
framework::Tensor conf_tensor;
loc_tensor.mutable_data<T>(loc_shape, context.GetPlace());
conf_tensor.mutable_data<T>(conf_shape, context.GetPlace());
// for cpu
framework::Tensor loc_cpu;
framework::Tensor conf_cpu;
framework::Tensor priorbox_cpu;
const T* priorbox_data = in_priorbox->data<T>();
transpose_fun<DeviceContext, T>(context, *in_loc, &loc_tensor);
transpose_fun<DeviceContext, T>(context, *in_conf, &conf_tensor);
conf_tensor.Resize(conf_shape_softmax);
math::SoftmaxFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), &conf_tensor,
&conf_tensor);
T* loc_data = loc_tensor.data<T>();
T* conf_data = conf_tensor.data<T>();
if (platform::is_gpu_place(context.GetPlace())) {
loc_cpu.mutable_data<T>(loc_tensor.dims(), platform::CPUPlace());
framework::CopyFrom(loc_tensor, platform::CPUPlace(),
context.device_context(), &loc_cpu);
loc_data = loc_cpu.data<T>();
conf_cpu.mutable_data<T>(conf_tensor.dims(), platform::CPUPlace());
framework::CopyFrom(conf_tensor, platform::CPUPlace(),
context.device_context(), &conf_cpu);
conf_data = conf_cpu.data<T>();
priorbox_cpu.mutable_data<T>(in_priorbox->dims(), platform::CPUPlace());
framework::CopyFrom(*in_priorbox, platform::CPUPlace(),
context.device_context(), &priorbox_cpu);
priorbox_data = priorbox_cpu.data<T>();
}
// get decode bboxes
size_t num_priors = in_priorbox->numel() / 8;
std::vector<std::vector<operators::math::BBox<T>>> all_decoded_bboxes;
for (size_t n = 0; n < batch_size; ++n) {
std::vector<operators::math::BBox<T>> decoded_bboxes;
for (size_t i = 0; i < num_priors; ++i) {
size_t prior_offset = i * 8;
size_t loc_pred_offset = n * num_priors * 4 + i * 4;
std::vector<math::BBox<T>> prior_bbox_vec;
math::GetBBoxFromPriorData<T>(priorbox_data + prior_offset, 1,
prior_bbox_vec);
std::vector<std::vector<T>> prior_bbox_var;
math::GetBBoxVarFromPriorData<T>(priorbox_data + prior_offset, 1,
prior_bbox_var);
std::vector<T> loc_pred_data;
for (size_t j = 0; j < 4; ++j)
loc_pred_data.push_back(*(loc_data + loc_pred_offset + j));
math::BBox<T> bbox = math::DecodeBBoxWithVar<T>(
prior_bbox_vec[0], prior_bbox_var[0], loc_pred_data);
decoded_bboxes.push_back(bbox);
}
all_decoded_bboxes.push_back(decoded_bboxes);
}
std::vector<std::map<size_t, std::vector<size_t>>> all_indices;
int num_kept = math::GetDetectionIndices<T>(
conf_data, num_priors, num_classes, background_label_id, batch_size,
confidence_threshold, nms_top_k, nms_threshold, top_k,
all_decoded_bboxes, &all_indices);
if (num_kept <= 0) {
std::vector<int64_t> out_shape_vec({0, 0});
framework::DDim out_shape(framework::make_ddim(out_shape_vec));
out->Resize(out_shape);
return;
}
std::vector<int64_t> out_shape_vec({num_kept, 7});
framework::DDim out_shape(framework::make_ddim(out_shape_vec));
out->mutable_data<T>(out_shape, context.GetPlace());
framework::Tensor out_cpu;
T* out_data = out->data<T>();
if (platform::is_gpu_place(context.GetPlace())) {
out_cpu.mutable_data<T>(out->dims(), platform::CPUPlace());
out_data = out_cpu.data<T>();
}
math::GetDetectionOutput<T>(conf_data, num_kept, num_priors, num_classes,
batch_size, all_indices, all_decoded_bboxes,
out_data);
if (platform::is_gpu_place(context.GetPlace())) {
framework::CopyFrom(out_cpu, platform::CUDAPlace(),
context.device_context(), out);
}
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include "paddle/framework/selected_rows.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
struct BBox {
BBox(T x_min, T y_min, T x_max, T y_max)
: x_min(x_min),
y_min(y_min),
x_max(x_max),
y_max(y_max),
is_difficult(false) {}
BBox() {}
T get_width() const { return x_max - x_min; }
T get_height() const { return y_max - y_min; }
T get_center_x() const { return (x_min + x_max) / 2; }
T get_center_y() const { return (y_min + y_max) / 2; }
T get_area() const { return get_width() * get_height(); }
// coordinate of bounding box
T x_min;
T y_min;
T x_max;
T y_max;
// whether difficult object (e.g. object with heavy occlusion is difficult)
bool is_difficult;
};
// KNCHW ==> NHWC
// template <typename T>
template <typename T>
void GetBBoxFromPriorData(const T* prior_data, const size_t num_bboxes,
std::vector<BBox<T>>& bbox_vec);
template <typename T>
void GetBBoxVarFromPriorData(const T* prior_data, const size_t num,
std::vector<std::vector<T>>& var_vec);
template <typename T>
BBox<T> DecodeBBoxWithVar(BBox<T>& prior_bbox,
const std::vector<T>& prior_bbox_var,
const std::vector<T>& loc_pred_data);
template <typename T1, typename T2>
bool SortScorePairDescend(const std::pair<T1, T2>& pair1,
const std::pair<T1, T2>& pair2);
template <typename T>
bool SortScorePairDescend(const std::pair<T, BBox<T>>& pair1,
const std::pair<T, BBox<T>>& pair2);
template <typename T>
T jaccard_overlap(const BBox<T>& bbox1, const BBox<T>& bbox2);
template <typename T>
void ApplyNmsFast(const std::vector<BBox<T>>& bboxes, const T* conf_score_data,
size_t class_idx, size_t top_k, T conf_threshold,
T nms_threshold, size_t num_priors, size_t num_classes,
std::vector<size_t>* indices);
template <typename T>
int GetDetectionIndices(
const T* conf_data, const size_t num_priors, const size_t num_classes,
const size_t background_label_id, const size_t batch_size,
const T conf_threshold, const size_t nms_top_k, const T nms_threshold,
const size_t top_k,
const std::vector<std::vector<BBox<T>>>& all_decoded_bboxes,
std::vector<std::map<size_t, std::vector<size_t>>>* all_detection_indices);
template <typename T>
BBox<T> ClipBBox(const BBox<T>& bbox);
template <typename T>
void GetDetectionOutput(
const T* conf_data, const size_t num_kept, const size_t num_priors,
const size_t num_classes, const size_t batch_size,
const std::vector<std::map<size_t, std::vector<size_t>>>& all_indices,
const std::vector<std::vector<BBox<T>>>& all_decoded_bboxes, T* out_data);
template <typename T>
void GetBBoxFromPriorData(const T* prior_data, const size_t num_bboxes,
std::vector<BBox<T>>& bbox_vec) {
size_t out_offset = bbox_vec.size();
bbox_vec.resize(bbox_vec.size() + num_bboxes);
for (size_t i = 0; i < num_bboxes; ++i) {
BBox<T> bbox;
bbox.x_min = *(prior_data + i * 8);
bbox.y_min = *(prior_data + i * 8 + 1);
bbox.x_max = *(prior_data + i * 8 + 2);
bbox.y_max = *(prior_data + i * 8 + 3);
bbox_vec[out_offset + i] = bbox;
}
}
template <typename T>
void GetBBoxVarFromPriorData(const T* prior_data, const size_t num,
std::vector<std::vector<T>>& var_vec) {
size_t out_offset = var_vec.size();
var_vec.resize(var_vec.size() + num);
for (size_t i = 0; i < num; ++i) {
std::vector<T> var;
var.push_back(*(prior_data + i * 8 + 4));
var.push_back(*(prior_data + i * 8 + 5));
var.push_back(*(prior_data + i * 8 + 6));
var.push_back(*(prior_data + i * 8 + 7));
var_vec[out_offset + i] = var;
}
}
template <typename T>
BBox<T> DecodeBBoxWithVar(BBox<T>& prior_bbox,
const std::vector<T>& prior_bbox_var,
const std::vector<T>& loc_pred_data) {
T prior_bbox_width = prior_bbox.get_width();
T prior_bbox_height = prior_bbox.get_height();
T prior_bbox_center_x = prior_bbox.get_center_x();
T prior_bbox_center_y = prior_bbox.get_center_y();
T decoded_bbox_center_x =
prior_bbox_var[0] * loc_pred_data[0] * prior_bbox_width +
prior_bbox_center_x;
T decoded_bbox_center_y =
prior_bbox_var[1] * loc_pred_data[1] * prior_bbox_height +
prior_bbox_center_y;
T decoded_bbox_width =
std::exp(prior_bbox_var[2] * loc_pred_data[2]) * prior_bbox_width;
T decoded_bbox_height =
std::exp(prior_bbox_var[3] * loc_pred_data[3]) * prior_bbox_height;
BBox<T> decoded_bbox;
decoded_bbox.x_min = decoded_bbox_center_x - decoded_bbox_width / 2;
decoded_bbox.y_min = decoded_bbox_center_y - decoded_bbox_height / 2;
decoded_bbox.x_max = decoded_bbox_center_x + decoded_bbox_width / 2;
decoded_bbox.y_max = decoded_bbox_center_y + decoded_bbox_height / 2;
return decoded_bbox;
}
template <typename T1, typename T2>
bool SortScorePairDescend(const std::pair<T1, T2>& pair1,
const std::pair<T1, T2>& pair2) {
return pair1.first > pair2.first;
}
template <typename T>
T jaccard_overlap(const BBox<T>& bbox1, const BBox<T>& bbox2) {
if (bbox2.x_min > bbox1.x_max || bbox2.x_max < bbox1.x_min ||
bbox2.y_min > bbox1.y_max || bbox2.y_max < bbox1.y_min) {
return 0.0;
} else {
T inter_x_min = std::max(bbox1.x_min, bbox2.x_min);
T inter_y_min = std::max(bbox1.y_min, bbox2.y_min);
T interX_max = std::min(bbox1.x_max, bbox2.x_max);
T interY_max = std::min(bbox1.y_max, bbox2.y_max);
T inter_width = interX_max - inter_x_min;
T inter_height = interY_max - inter_y_min;
T inter_area = inter_width * inter_height;
T bbox_area1 = bbox1.get_area();
T bbox_area2 = bbox2.get_area();
return inter_area / (bbox_area1 + bbox_area2 - inter_area);
}
}
template <typename T>
void ApplyNmsFast(const std::vector<BBox<T>>& bboxes, const T* conf_score_data,
size_t class_idx, size_t top_k, T conf_threshold,
T nms_threshold, size_t num_priors, size_t num_classes,
std::vector<size_t>* indices) {
std::vector<std::pair<T, size_t>> scores;
for (size_t i = 0; i < num_priors; ++i) {
size_t conf_offset = i * num_classes + class_idx;
if (conf_score_data[conf_offset] > conf_threshold)
scores.push_back(std::make_pair(conf_score_data[conf_offset], i));
}
std::stable_sort(scores.begin(), scores.end(),
SortScorePairDescend<T, size_t>);
if (top_k > 0 && top_k < scores.size()) scores.resize(top_k);
while (scores.size() > 0) {
const size_t idx = scores.front().second;
bool keep = true;
for (size_t i = 0; i < indices->size(); ++i) {
if (keep) {
const size_t saved_idx = (*indices)[i];
T overlap = jaccard_overlap<T>(bboxes[idx], bboxes[saved_idx]);
keep = overlap <= nms_threshold;
} else {
break;
}
}
if (keep) indices->push_back(idx);
scores.erase(scores.begin());
}
}
template <typename T>
int GetDetectionIndices(
const T* conf_data, const size_t num_priors, const size_t num_classes,
const size_t background_label_id, const size_t batch_size,
const T conf_threshold, const size_t nms_top_k, const T nms_threshold,
const size_t top_k,
const std::vector<std::vector<BBox<T>>>& all_decoded_bboxes,
std::vector<std::map<size_t, std::vector<size_t>>>* all_detection_indices) {
int total_keep_num = 0;
for (size_t n = 0; n < batch_size; ++n) {
const std::vector<BBox<T>>& decoded_bboxes = all_decoded_bboxes[n];
size_t num_detected = 0;
std::map<size_t, std::vector<size_t>> indices;
size_t conf_offset = n * num_priors * num_classes;
for (size_t c = 0; c < num_classes; ++c) {
if (c == background_label_id) continue;
ApplyNmsFast<T>(decoded_bboxes, conf_data + conf_offset, c, nms_top_k,
conf_threshold, nms_threshold, num_priors, num_classes,
&(indices[c]));
num_detected += indices[c].size();
}
if (top_k > 0 && num_detected > top_k) {
// std::vector<pair<T,T>> score_index_pairs;
std::vector<std::pair<T, std::pair<size_t, size_t>>> score_index_pairs;
for (size_t c = 0; c < num_classes; ++c) {
const std::vector<size_t>& label_indices = indices[c];
for (size_t i = 0; i < label_indices.size(); ++i) {
size_t idx = label_indices[i];
score_index_pairs.push_back(
std::make_pair((conf_data + conf_offset)[idx * num_classes + c],
std::make_pair(c, idx)));
}
}
std::sort(score_index_pairs.begin(), score_index_pairs.end(),
SortScorePairDescend<T, std::pair<size_t, size_t>>);
score_index_pairs.resize(top_k);
std::map<size_t, std::vector<size_t>> new_indices;
for (size_t i = 0; i < score_index_pairs.size(); ++i) {
size_t label = score_index_pairs[i].second.first;
size_t idx = score_index_pairs[i].second.second;
new_indices[label].push_back(idx);
}
all_detection_indices->push_back(new_indices);
total_keep_num += top_k;
} else {
all_detection_indices->push_back(indices);
total_keep_num += num_detected;
}
}
return total_keep_num;
}
template <typename T>
BBox<T> ClipBBox(const BBox<T>& bbox) {
T one = static_cast<T>(1.0);
T zero = static_cast<T>(0.0);
BBox<T> clipped_bbox;
clipped_bbox.x_min = std::max(std::min(bbox.x_min, one), zero);
clipped_bbox.y_min = std::max(std::min(bbox.y_min, one), zero);
clipped_bbox.x_max = std::max(std::min(bbox.x_max, one), zero);
clipped_bbox.y_max = std::max(std::min(bbox.y_max, one), zero);
return clipped_bbox;
}
template <typename T>
void GetDetectionOutput(
const T* conf_data, const size_t num_kept, const size_t num_priors,
const size_t num_classes, const size_t batch_size,
const std::vector<std::map<size_t, std::vector<size_t>>>& all_indices,
const std::vector<std::vector<BBox<T>>>& all_decoded_bboxes, T* out_data) {
size_t count = 0;
for (size_t n = 0; n < batch_size; ++n) {
for (std::map<size_t, std::vector<size_t>>::const_iterator it =
all_indices[n].begin();
it != all_indices[n].end(); ++it) {
size_t label = it->first;
const std::vector<size_t>& indices = it->second;
const std::vector<BBox<T>>& decoded_bboxes = all_decoded_bboxes[n];
for (size_t i = 0; i < indices.size(); ++i) {
size_t idx = indices[i];
size_t conf_offset = n * num_priors * num_classes + idx * num_classes;
out_data[count * 7] = n;
out_data[count * 7 + 1] = label;
out_data[count * 7 + 2] = (conf_data + conf_offset)[label];
BBox<T> clipped_bbox = ClipBBox<T>(decoded_bboxes[idx]);
out_data[count * 7 + 3] = clipped_bbox.x_min;
out_data[count * 7 + 4] = clipped_bbox.y_min;
out_data[count * 7 + 5] = clipped_bbox.x_max;
out_data[count * 7 + 6] = clipped_bbox.y_max;
++count;
}
}
}
}
} // namespace math
} // namespace operators
} // namespace paddle
import unittest
import numpy as np
from op_test import OpTest
class TestUnpoolOp(OpTest):
def setUp(self):
self.op_type = "detection_output"
self.init_test_case()
#loc.shape ((1, 4, 4, 1, 1))
#conf.shape ((1, 4, 2, 1, 1))
loc = np.array([[[[[0.1]], [[0.1]], [[0.1]], [[0.1]]],
[[[0.1]], [[0.1]], [[0.1]], [[0.1]]],
[[[0.1]], [[0.1]], [[0.1]], [[0.1]]],
[[[0.1]], [[0.1]], [[0.1]], [[0.1]]]]])
conf = np.array([[[[[0.1]], [[0.9]]], [[[0.2]], [[0.8]]],
[[[0.3]], [[0.7]]], [[[0.4]], [[0.6]]]]])
priorbox = np.array([
0.1, 0.1, 0.5, 0.5, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.6, 0.6, 0.1,
0.1, 0.2, 0.2, 0.3, 0.3, 0.7, 0.7, 0.1, 0.1, 0.2, 0.2, 0.4, 0.4,
0.8, 0.8, 0.1, 0.1, 0.2, 0.2
])
output = np.array([
0, 1, 0.68997443, 0.099959746, 0.099959746, 0.50804031, 0.50804031
])
self.inputs = {
'Loc': loc.astype('float32'),
'Conf': conf.astype('float32'),
'PriorBox': priorbox.astype('float32')
}
self.attrs = {
'num_classes': self.num_classes,
'top_k': self.top_k,
'nms_top_k': self.nms_top_k,
'background_label_id': self.background_label_id,
'nms_threshold': self.nms_threshold,
'confidence_threshold': self.confidence_threshold,
}
self.outputs = {'Out': output.astype('float32')}
def test_check_output(self):
self.check_output()
def init_test_case(self):
self.num_classes = 2
self.top_k = 10
self.nms_top_k = 20
self.background_label_id = 0
self.nms_threshold = 0.01
self.confidence_threshold = 0.01
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册