Commit ffff1b01 authored by wangzhe

detection_post_process op support quantized anchors tensor

Parent c6ba87d9
......@@ -33,6 +33,8 @@ typedef struct DetectionPostProcessParameter {
bool use_regular_nms_;
bool out_quantized_;
float *anchors_;
void *decoded_boxes_;
void *nms_candidate_;
void *selected_;
......
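
For orientation, a minimal sketch (not part of the diff) of the parameter block as this hunk shows it, with comments on the role each pointer appears to play based on the kernel code further down; fields not visible in the hunk are elided.

// Sketch only, assuming the roles implied by the kernel changes below.
typedef struct DetectionPostProcessParameter {
  // ... op attributes not shown in the hunk ...
  bool use_regular_nms_;   // selects the regular per-class NMS path
  bool out_quantized_;     // whether outputs are expected in quantized form
  float *anchors_;         // fp32 anchors; filled in Init(), dequantized there when the tensor is uint8
  void *decoded_boxes_;    // per-run scratch: boxes decoded to corner form
  void *nms_candidate_;    // per-run scratch used by NMS
  void *selected_;         // per-run scratch used by NMS
  // ...
} DetectionPostProcessParameter;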
......@@ -44,15 +44,15 @@ float IntersectionOverUnion(const BboxCorner *a, const BboxCorner *b) {
const float h = ymax - ymin > 0.0f ? ymax - ymin : 0.0f;
const float w = xmax - xmin > 0.0f ? xmax - xmin : 0.0f;
const float inter = h * w;
return inter / (area_a + area_b - inter + 1e-8);
return inter / (area_a + area_b - inter);
}
void DecodeBoxes(const int num_boxes, const float *input_boxes, const float *anchors, const BboxCenter scaler,
float *decoded_boxes) {
for (int i = 0; i < num_boxes; ++i) {
BboxCenter *box = (BboxCenter *)(input_boxes + i * 4);
BboxCenter *anchor = (BboxCenter *)(anchors + i * 4);
BboxCorner *decoded_box = (BboxCorner *)(decoded_boxes + i * 4);
BboxCenter *box = (BboxCenter *)(input_boxes) + i;
BboxCenter *anchor = (BboxCenter *)(anchors) + i;
BboxCorner *decoded_box = (BboxCorner *)(decoded_boxes) + i;
float y_center = box->y / scaler.y * anchor->h + anchor->y;
float x_center = box->x / scaler.x * anchor->w + anchor->x;
float h_half = 0.5f * expf(box->h / scaler.h) * anchor->h;
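
DecodeBoxes converts each (y, x, h, w) center-form prediction into a corner-form box using the matching anchor and per-coordinate scalers. Below is a minimal standalone sketch of that arithmetic; the struct and field names mirror the nnacl code, but the exact field order there is not shown here, and the width half-extent is assumed symmetric to the height one since its line is cut off in this hunk. Note also that the new pointer arithmetic is equivalent to the old one provided BboxCenter is exactly four packed floats: (BboxCenter *)(anchors) + i and (BboxCenter *)(anchors + i * 4) then address the same element.

#include <cmath>

// Sketch only: decode one center-form prediction against its anchor.
struct BboxCenter { float y; float x; float h; float w; };
struct BboxCorner { float ymin; float xmin; float ymax; float xmax; };

BboxCorner DecodeOneBox(const BboxCenter &box, const BboxCenter &anchor, const BboxCenter &scaler) {
  const float y_center = box.y / scaler.y * anchor.h + anchor.y;
  const float x_center = box.x / scaler.x * anchor.w + anchor.x;
  const float h_half = 0.5f * std::exp(box.h / scaler.h) * anchor.h;
  const float w_half = 0.5f * std::exp(box.w / scaler.w) * anchor.w;  // assumed symmetric to h_half
  return BboxCorner{y_center - h_half, x_center - w_half, y_center + h_half, x_center + w_half};
}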
......@@ -137,7 +137,7 @@ int NmsMultiClassesRegular(const int num_boxes, const int num_classes_with_bg, c
const int class_index = score_with_index_all[i].index - box_index * num_classes_with_bg - first_class_index;
*((BboxCorner *)(output_boxes) + i) = *((BboxCorner *)(decoded_boxes) + box_index);
output_classes[i] = (float)class_index;
output_scores[i] = score_with_index_all[i].score;;
output_scores[i] = score_with_index_all[i].score;
} else {
((BboxCorner *)(output_boxes) + i)->ymin = 0;
((BboxCorner *)(output_boxes) + i)->xmin = 0;
......
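
The recovered class index in this hunk relies on the scores being flattened row-major as [box][class-with-background], with the first first_class_index slots holding background classes. A small standalone sketch of that inversion follows; recovering box_index by integer division is an assumption about the surrounding code, which is not shown in the hunk.

// Sketch only: invert a flattened score index back to (box, class).
struct BoxClass { int box_index; int class_index; };

BoxClass RecoverBoxClass(int flat_index, int num_classes_with_bg, int first_class_index) {
  const int box_index = flat_index / num_classes_with_bg;                                     // assumed row-major layout
  const int class_index = flat_index - box_index * num_classes_with_bg - first_class_index;   // as in the diff
  return BoxClass{box_index, class_index};
}

// Example: with num_classes_with_bg = 91 and first_class_index = 1 (illustrative values),
// flat_index = 5 * 91 + 1 + 17 = 473 recovers box_index = 5 and class_index = 17.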
......@@ -26,7 +26,32 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_DetectionPostProcess;
namespace mindspore::kernel {
int DetectionPostProcessCPUKernel::Init() { return RET_OK; }
int DetectionPostProcessCPUKernel::Init() {
MS_ASSERT(context_->allocator != nullptr);
auto anchor_tensor = in_tensors_.at(2);
DetectionPostProcessParameter *parameter = reinterpret_cast<DetectionPostProcessParameter *>(op_parameter_);
if (anchor_tensor->data_type() == kNumberTypeUInt8) {
const auto quant_params = anchor_tensor->GetQuantParams();
const double scale = quant_params.at(0).scale;
const int32_t zp = quant_params.at(0).zeroPoint;
auto anchor_uint8 = reinterpret_cast<uint8_t *>(anchor_tensor->Data());
auto anchor_fp32 =
reinterpret_cast<float *>(context_->allocator->Malloc(anchor_tensor->ElementsNum() * sizeof(float)));
for (int i = 0; i < anchor_tensor->ElementsNum(); ++i) {
*(anchor_fp32 + i) = static_cast<float>((static_cast<int>(anchor_uint8[i]) - zp) * scale);
}
parameter->anchors_ = anchor_fp32;
} else if (anchor_tensor->data_type() == kNumberTypeFloat32) {
auto anchor_fp32 = reinterpret_cast<float *>(anchor_tensor->Data());
for (int i = 0; i < anchor_tensor->ElementsNum(); ++i) {
parameter->anchors_[i] = anchor_fp32[i];
}
} else {
MS_LOG(ERROR) << "unsupported anchor data type " << anchor_tensor->data_type();
return RET_ERROR;
}
return RET_OK;
}
int DetectionPostProcessCPUKernel::ReSize() { return RET_OK; }
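
The new Init() branch is an affine dequantization, real = scale * (q - zero_point), applied once so the rest of the op keeps working on fp32 anchors. A minimal standalone sketch of that conversion (illustrative names, not the MindSpore Lite API):

#include <cstdint>
#include <vector>

// Sketch only: dequantize uint8 anchor data to fp32 with one affine quant parameter,
// mirroring the loop in DetectionPostProcessCPUKernel::Init().
std::vector<float> DequantizeAnchors(const uint8_t *data, int element_num, double scale, int32_t zero_point) {
  std::vector<float> anchors_fp32(element_num);
  for (int i = 0; i < element_num; ++i) {
    anchors_fp32[i] = static_cast<float>((static_cast<int>(data[i]) - zero_point) * scale);
  }
  return anchors_fp32;
}

Note that in the diff the uint8 branch takes its fp32 buffer from context_->allocator, while the fp32 branch copies element by element into whatever parameter->anchors_ already points at.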
......@@ -38,7 +63,6 @@ int DetectionPostProcessCPUKernel::Run() {
}
auto input_boxes = reinterpret_cast<float *>(in_tensors_.at(0)->Data());
auto input_scores = reinterpret_cast<float *>(in_tensors_.at(1)->Data());
auto input_anchors = reinterpret_cast<float *>(in_tensors_.at(2)->Data());
// output_classes and output_num use float type now
auto output_boxes = reinterpret_cast<float *>(out_tensors_.at(0)->Data());
......@@ -61,7 +85,7 @@ int DetectionPostProcessCPUKernel::Run() {
parameter->score_with_class_all_ =
context_->allocator->Malloc((num_boxes * parameter->num_classes_) * sizeof(ScoreWithIndex));
}
DetectionPostProcess(num_boxes, num_classes_with_bg, input_boxes, input_scores, input_anchors, output_boxes,
DetectionPostProcess(num_boxes, num_classes_with_bg, input_boxes, input_scores, parameter->anchors_, output_boxes,
output_classes, output_scores, output_num, parameter);
context_->allocator->Free(parameter->decoded_boxes_);
context_->allocator->Free(parameter->nms_candidate_);
......
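
Run() still allocates its scratch buffers per call and frees them through the same allocator; only the anchor read changes, since parameter->anchors_ is prepared in Init(). As a rough, illustrative sizing of what Run() allocates (1917 anchors comes from the test tensor shape below; the class count and the ScoreWithIndex layout are assumptions):

#include <cstddef>
#include <cstdio>

// Sketch only: rough scratch sizes for the regular-NMS path.
struct ScoreWithIndex { float score; int index; };  // layout assumed for illustration

int main() {
  const int num_boxes = 1917;   // anchor count from the test shape {1917, 4}
  const int num_classes = 90;   // illustrative foreground class count
  const size_t decoded_boxes_bytes = num_boxes * 4 * sizeof(float);                        // corner-form boxes
  const size_t score_with_class_bytes = num_boxes * num_classes * sizeof(ScoreWithIndex);  // as allocated in Run()
  std::printf("decoded boxes: %zu bytes, score/index pairs: %zu bytes\n",
              decoded_boxes_bytes, score_with_class_bytes);
  return 0;
}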
......@@ -56,9 +56,13 @@ void DetectionPostProcessTestInit(std::vector<lite::tensor::Tensor *> *inputs_,
std::string input_anchors_path = "./test_data/detectionPostProcess/input_anchors.bin";
size_t input_anchors_size;
auto input_anchors_data =
reinterpret_cast<float *>(mindspore::lite::ReadFile(input_anchors_path.c_str(), &input_anchors_size));
reinterpret_cast<uint8_t *>(mindspore::lite::ReadFile(input_anchors_path.c_str(), &input_anchors_size));
auto *input_anchors = new lite::tensor::Tensor;
input_anchors->set_data_type(kNumberTypeFloat32);
lite::tensor::QuantArg quant_arg;
quant_arg.zeroPoint = 0;
quant_arg.scale = 0.00645306;
input_anchors->AddQuantParam(quant_arg);
input_anchors->set_data_type(kNumberTypeUInt8);
input_anchors->SetFormat(schema::Format_NHWC);
input_anchors->set_shape({1917, 4});
input_anchors->MallocData();
......
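
The test now feeds the anchors as uint8 with zeroPoint 0 and scale 0.00645306, so each stored byte maps back to roughly scale * q. A quick illustrative check of the value range this gives (the sample bytes are made up, not the contents of input_anchors.bin):

#include <cstdint>
#include <cstdio>

// Sketch only: what uint8 anchor bytes dequantize to under the QuantArg above.
int main() {
  const double scale = 0.00645306;
  const int32_t zero_point = 0;
  const uint8_t samples[] = {0, 128, 255};  // illustrative bytes, not read from the test file
  for (uint8_t q : samples) {
    std::printf("q=%3d -> %f\n", static_cast<int>(q), (static_cast<int>(q) - zero_point) * scale);
  }
  return 0;
}

// With zeroPoint 0, the representable anchor values span about [0, 255 * 0.00645306] ≈ [0, 1.646].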