未验证 提交 de43e479 编写于 作者: W Wilber 提交者: GitHub

Add yolo_box_cuda multiclass_nms_host kernel. (#1908)

* add yolo_box_compute cuda

* move multiclass_nms(arm) to host

* add lod in scale op

* add yolo_box_cuda cmake config

* modify shuffle_channel_fuse and transpose_softmax_transpose_fuse to support run ssd model. test=develop

* reshape and transpose op don't have xshape output.

* modify yolo_box_compute_cuda, use tensor to manage cuda memory test=develop

* add yolo_box use kernel test=develop
上级 79714d74
...@@ -31,6 +31,7 @@ USE_LITE_KERNEL(fetch, kFPGA, kFP16, kNHWC, def); ...@@ -31,6 +31,7 @@ USE_LITE_KERNEL(fetch, kFPGA, kFP16, kNHWC, def);
// host kernels // host kernels
USE_LITE_KERNEL(reshape, kHost, kAny, kAny, def); USE_LITE_KERNEL(reshape, kHost, kAny, kAny, def);
USE_LITE_KERNEL(reshape2, kHost, kAny, kAny, def); USE_LITE_KERNEL(reshape2, kHost, kAny, kAny, def);
USE_LITE_KERNEL(multiclass_nms, kHost, kFloat, kNCHW, def);
#ifdef LITE_WITH_ARM #ifdef LITE_WITH_ARM
USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def);
...@@ -92,7 +93,6 @@ USE_LITE_KERNEL(top_k, kARM, kFloat, kNCHW, def); ...@@ -92,7 +93,6 @@ USE_LITE_KERNEL(top_k, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(increment, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(increment, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(write_to_array, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(write_to_array, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(read_from_array, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(read_from_array, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(multiclass_nms, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(reduce_max, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(reduce_max, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(sequence_expand, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(sequence_expand, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(sequence_pool, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(sequence_pool, kARM, kFloat, kNCHW, def);
...@@ -155,6 +155,7 @@ USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host); ...@@ -155,6 +155,7 @@ USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host);
USE_LITE_KERNEL(io_copy_once, kCUDA, kAny, kAny, host_to_device); USE_LITE_KERNEL(io_copy_once, kCUDA, kAny, kAny, host_to_device);
USE_LITE_KERNEL(io_copy_once, kCUDA, kAny, kAny, device_to_host); USE_LITE_KERNEL(io_copy_once, kCUDA, kAny, kAny, device_to_host);
USE_LITE_KERNEL(leaky_relu, kCUDA, kFloat, kNCHW, def); USE_LITE_KERNEL(leaky_relu, kCUDA, kFloat, kNCHW, def);
USE_LITE_KERNEL(yolo_box, kCUDA, kFloat, kNCHW, def);
#endif #endif
#ifdef LITE_WITH_OPENCL #ifdef LITE_WITH_OPENCL
......
...@@ -56,7 +56,6 @@ if (NOT HAS_ARM_MATH_LIB_DIR) ...@@ -56,7 +56,6 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
elementwise.cc elementwise.cc
lrn.cc lrn.cc
decode_bboxes.cc decode_bboxes.cc
multiclass_nms.cc
concat.cc concat.cc
sgemv.cc sgemv.cc
type_trans.cc type_trans.cc
......
...@@ -39,7 +39,6 @@ ...@@ -39,7 +39,6 @@
#include "lite/arm/math/increment.h" #include "lite/arm/math/increment.h"
#include "lite/arm/math/interpolate.h" #include "lite/arm/math/interpolate.h"
#include "lite/arm/math/lrn.h" #include "lite/arm/math/lrn.h"
#include "lite/arm/math/multiclass_nms.h"
#include "lite/arm/math/negative.h" #include "lite/arm/math/negative.h"
#include "lite/arm/math/norm.h" #include "lite/arm/math/norm.h"
#include "lite/arm/math/packed_sgemm.h" #include "lite/arm/math/packed_sgemm.h"
......
...@@ -28,12 +28,17 @@ void ShuffleChannelFuser::BuildPattern() { ...@@ -28,12 +28,17 @@ void ShuffleChannelFuser::BuildPattern() {
auto* y2 = VarNode("y2")->assert_is_op_output(transpose_type_, "Out"); auto* y2 = VarNode("y2")->assert_is_op_output(transpose_type_, "Out");
auto* out = VarNode("out")->assert_is_op_output(reshape_type_, "Out"); auto* out = VarNode("out")->assert_is_op_output(reshape_type_, "Out");
auto* xshape1 = PMNode* xshape1 = nullptr;
VarNode("xshape1")->assert_is_op_output(reshape_type_, "XShape"); PMNode* xshape2 = nullptr;
auto* xshape2 = PMNode* xshape3 = nullptr;
VarNode("xshape2")->assert_is_op_output(transpose_type_, "XShape"); if (reshape_type_ == "reshape2") {
auto* xshape3 = xshape1 = VarNode("xshape1")->assert_is_op_output(reshape_type_, "XShape");
VarNode("xshape3")->assert_is_op_output(reshape_type_, "XShape"); xshape3 = VarNode("xshape3")->assert_is_op_output(reshape_type_, "XShape");
}
if (transpose_type_ == "transpose2") {
xshape2 =
VarNode("xshape2")->assert_is_op_output(transpose_type_, "XShape");
}
auto* reshape1 = OpNode("reshape1", reshape_type_) auto* reshape1 = OpNode("reshape1", reshape_type_)
->assert_op_attr_satisfied<std::vector<int>>( ->assert_op_attr_satisfied<std::vector<int>>(
...@@ -54,16 +59,16 @@ void ShuffleChannelFuser::BuildPattern() { ...@@ -54,16 +59,16 @@ void ShuffleChannelFuser::BuildPattern() {
// create topology. // create topology.
*x1 >> *reshape1 >> *y1 >> *transpose >> *y2 >> *reshape2 >> *out; *x1 >> *reshape1 >> *y1 >> *transpose >> *y2 >> *reshape2 >> *out;
*reshape1 >> *xshape1; if (xshape1) *reshape1 >> *xshape1;
*transpose >> *xshape2; if (xshape2) *transpose >> *xshape2;
*reshape2 >> *xshape3; if (xshape3) *reshape2 >> *xshape3;
// Some op specialities. // Some op specialities.
y1->AsIntermediate(); y1->AsIntermediate();
y2->AsIntermediate(); y2->AsIntermediate();
xshape1->AsIntermediate(); if (xshape1) xshape1->AsIntermediate();
xshape2->AsIntermediate(); if (xshape2) xshape2->AsIntermediate();
xshape3->AsIntermediate(); if (xshape3) xshape3->AsIntermediate();
reshape1->AsIntermediate(); reshape1->AsIntermediate();
transpose->AsIntermediate(); transpose->AsIntermediate();
reshape2->AsIntermediate(); reshape2->AsIntermediate();
......
...@@ -28,10 +28,14 @@ void TransposeSoftmaxTransposeFuser::BuildPattern() { ...@@ -28,10 +28,14 @@ void TransposeSoftmaxTransposeFuser::BuildPattern() {
auto* y2 = VarNode("y2")->assert_is_op_output(softmax_type_, "Out"); auto* y2 = VarNode("y2")->assert_is_op_output(softmax_type_, "Out");
auto* out = VarNode("out")->assert_is_op_output(transpose_type_, "Out"); auto* out = VarNode("out")->assert_is_op_output(transpose_type_, "Out");
auto* xshape1 = PMNode* xshape1 = nullptr;
VarNode("xshape1")->assert_is_op_output(transpose_type_, "XShape"); PMNode* xshape2 = nullptr;
auto* xshape2 = if (transpose_type_ == "transpose2") {
VarNode("xshape2")->assert_is_op_output(transpose_type_, "XShape"); xshape1 =
VarNode("xshape1")->assert_is_op_output(transpose_type_, "XShape");
xshape2 =
VarNode("xshape2")->assert_is_op_output(transpose_type_, "XShape");
}
auto* transpose1 = auto* transpose1 =
OpNode("transpose1", transpose_type_)->assert_is_op(transpose_type_); OpNode("transpose1", transpose_type_)->assert_is_op(transpose_type_);
...@@ -45,14 +49,14 @@ void TransposeSoftmaxTransposeFuser::BuildPattern() { ...@@ -45,14 +49,14 @@ void TransposeSoftmaxTransposeFuser::BuildPattern() {
// create topology. // create topology.
*x1 >> *transpose1 >> *y1 >> *softmax >> *y2 >> *transpose2 >> *out; *x1 >> *transpose1 >> *y1 >> *softmax >> *y2 >> *transpose2 >> *out;
*transpose1 >> *xshape1; if (xshape1) *transpose1 >> *xshape1;
*transpose2 >> *xshape2; if (xshape2) *transpose2 >> *xshape2;
// nodes to remove // nodes to remove
y1->AsIntermediate(); y1->AsIntermediate();
y2->AsIntermediate(); y2->AsIntermediate();
xshape1->AsIntermediate(); if (xshape1) xshape1->AsIntermediate();
xshape2->AsIntermediate(); if (xshape2) xshape2->AsIntermediate();
transpose1->AsIntermediate(); transpose1->AsIntermediate();
softmax->AsIntermediate(); softmax->AsIntermediate();
transpose2->AsIntermediate(); transpose2->AsIntermediate();
......
...@@ -15,7 +15,6 @@ add_kernel(batch_norm_compute_arm ARM basic SRCS batch_norm_compute.cc DEPS ${li ...@@ -15,7 +15,6 @@ add_kernel(batch_norm_compute_arm ARM basic SRCS batch_norm_compute.cc DEPS ${li
add_kernel(elementwise_compute_arm ARM basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(elementwise_compute_arm ARM basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(lrn_compute_arm ARM basic SRCS lrn_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(lrn_compute_arm ARM basic SRCS lrn_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(decode_bboxes_compute_arm ARM basic SRCS decode_bboxes_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(decode_bboxes_compute_arm ARM basic SRCS decode_bboxes_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(multiclass_nms_compute_arm ARM basic SRCS multiclass_nms_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(pool_compute_arm ARM basic SRCS pool_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(pool_compute_arm ARM basic SRCS pool_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(split_compute_arm ARM basic SRCS split_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(split_compute_arm ARM basic SRCS split_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(concat_compute_arm ARM basic SRCS concat_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(concat_compute_arm ARM basic SRCS concat_compute.cc DEPS ${lite_kernel_deps} math_arm)
...@@ -80,7 +79,6 @@ lite_cc_test(test_batch_norm_compute_arm SRCS batch_norm_compute_test.cc DEPS ba ...@@ -80,7 +79,6 @@ lite_cc_test(test_batch_norm_compute_arm SRCS batch_norm_compute_test.cc DEPS ba
lite_cc_test(test_elementwise_compute_arm SRCS elementwise_compute_test.cc DEPS elementwise_compute_arm) lite_cc_test(test_elementwise_compute_arm SRCS elementwise_compute_test.cc DEPS elementwise_compute_arm)
lite_cc_test(test_lrn_compute_arm SRCS lrn_compute_test.cc DEPS lrn_compute_arm) lite_cc_test(test_lrn_compute_arm SRCS lrn_compute_test.cc DEPS lrn_compute_arm)
lite_cc_test(test_decode_bboxes_compute_arm SRCS decode_bboxes_compute_test.cc DEPS decode_bboxes_compute_arm) lite_cc_test(test_decode_bboxes_compute_arm SRCS decode_bboxes_compute_test.cc DEPS decode_bboxes_compute_arm)
lite_cc_test(test_multiclass_nms_compute_arm SRCS multiclass_nms_compute_test.cc DEPS multiclass_nms_compute_arm)
lite_cc_test(test_pool_compute_arm SRCS pool_compute_test.cc DEPS pool_compute_arm) lite_cc_test(test_pool_compute_arm SRCS pool_compute_test.cc DEPS pool_compute_arm)
lite_cc_test(test_mul_compute_arm SRCS mul_compute_test.cc DEPS mul_compute_arm) lite_cc_test(test_mul_compute_arm SRCS mul_compute_test.cc DEPS mul_compute_arm)
lite_cc_test(test_split_compute_arm SRCS split_compute_test.cc DEPS split_compute_arm) lite_cc_test(test_split_compute_arm SRCS split_compute_test.cc DEPS split_compute_arm)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/multiclass_nms_compute.h"
#include <string>
#include "lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void MulticlassNmsCompute::Run() {
auto& param = Param<operators::MulticlassNmsParam>();
// bbox shape : N, M, 4
// scores shape : N, C, M
const float* bbox_data = param.bbox_data->data<float>();
const float* conf_data = param.conf_data->data<float>();
CHECK_EQ(param.bbox_data->dims().production() % 4, 0);
std::vector<float> result;
int N = param.bbox_data->dims()[0];
int M = param.bbox_data->dims()[1];
std::vector<int> priors(N, M);
int class_num = param.conf_data->dims()[1];
int background_label = param.background_label;
int keep_top_k = param.keep_top_k;
int nms_top_k = param.nms_top_k;
float score_threshold = param.score_threshold;
float nms_threshold = param.nms_threshold;
float nms_eta = param.nms_eta;
bool share_location = param.share_location;
lite::arm::math::multiclass_nms(bbox_data,
conf_data,
&result,
priors,
class_num,
background_label,
keep_top_k,
nms_top_k,
score_threshold,
nms_threshold,
nms_eta,
share_location);
lite::LoD* lod = param.out->mutable_lod();
std::vector<uint64_t> lod_info;
lod_info.push_back(0);
std::vector<float> result_corrected;
int tmp_batch_id;
uint64_t num = 0;
for (int i = 0; i < result.size(); ++i) {
if (i == 0) {
tmp_batch_id = result[i];
}
if (i % 7 == 0) {
if (result[i] == tmp_batch_id) {
++num;
} else {
lod_info.push_back(num);
++num;
tmp_batch_id = result[i];
}
} else {
result_corrected.push_back(result[i]);
}
}
lod_info.push_back(num);
(*lod).push_back(lod_info);
if (result_corrected.empty()) {
(*lod).clear();
(*lod).push_back(std::vector<uint64_t>({0, 1}));
param.out->Resize({static_cast<int64_t>(1)});
param.out->mutable_data<float>()[0] = -1.;
} else {
param.out->Resize({static_cast<int64_t>(result_corrected.size() / 6), 6});
float* out = param.out->mutable_data<float>();
std::memcpy(
out, result_corrected.data(), sizeof(float) * result_corrected.size());
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(multiclass_nms,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::MulticlassNmsCompute,
def)
.BindInput("BBoxes", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Scores", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
...@@ -32,6 +32,9 @@ void ScaleCompute::Run() { ...@@ -32,6 +32,9 @@ void ScaleCompute::Run() {
bias *= scale; bias *= scale;
} }
lite::arm::math::scale(x_data, output_data, x_dims.production(), scale, bias); lite::arm::math::scale(x_data, output_data, x_dims.production(), scale, bias);
if (!param.x->lod().empty()) {
param.output->set_lod(param.x->lod());
}
} }
} // namespace arm } // namespace arm
......
...@@ -7,13 +7,15 @@ message(STATUS "compile with lite CUDA kernels") ...@@ -7,13 +7,15 @@ message(STATUS "compile with lite CUDA kernels")
nv_library(mul_compute_cuda SRCS mul_compute.cc DEPS ${lite_kernel_deps} context) nv_library(mul_compute_cuda SRCS mul_compute.cc DEPS ${lite_kernel_deps} context)
lite_cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS ${lite_kernel_deps}) lite_cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS ${lite_kernel_deps})
nv_library(leaky_relu_compute_cuda SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps}) nv_library(leaky_relu_compute_cuda SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps})
lite_cc_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_relu_compute_cuda) lite_cc_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_relu_compute_cuda)
nv_library(yolo_box_compute_cuda SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps})
lite_cc_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda)
set(cuda_kernels set(cuda_kernels
mul_compute_cuda mul_compute_cuda
io_copy_compute_cuda io_copy_compute_cuda
leaky_relu_compute_cuda leaky_relu_compute_cuda
yolo_box_compute_cuda
) )
set(cuda_kernels "${cuda_kernels}" CACHE GLOBAL "cuda kernels") set(cuda_kernels "${cuda_kernels}" CACHE GLOBAL "cuda kernels")
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/yolo_box_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
__host__ __device__ inline int GetEntryIndex(int batch,
int an_idx,
int hw_idx,
int an_num,
int an_stride,
int stride,
int entry) {
return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx;
}
template <typename T>
__host__ __device__ inline T sigmoid(T x) {
return 1.0 / (1.0 + std::exp(-x));
}
template <typename T>
__host__ __device__ inline void GetYoloBox(T* box,
const T* x,
const int* anchors,
int i,
int j,
int an_idx,
int grid_size,
int input_size,
int index,
int stride,
int img_height,
int img_width) {
box[0] = (i + sigmoid<T>(x[index])) * img_width / grid_size;
box[1] = (j + sigmoid<T>(x[index + stride])) * img_height / grid_size;
box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width /
input_size;
box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] *
img_height / input_size;
}
template <typename T>
__host__ __device__ inline void CalcDetectionBox(T* boxes,
T* box,
const int box_idx,
const int img_height,
const int img_width) {
boxes[box_idx] = box[0] - box[2] / 2;
boxes[box_idx + 1] = box[1] - box[3] / 2;
boxes[box_idx + 2] = box[0] + box[2] / 2;
boxes[box_idx + 3] = box[1] + box[3] / 2;
boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : static_cast<T>(0);
boxes[box_idx + 1] =
boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast<T>(0);
boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1
? boxes[box_idx + 2]
: static_cast<T>(img_width - 1);
boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1
? boxes[box_idx + 3]
: static_cast<T>(img_height - 1);
}
template <typename T>
__host__ __device__ inline void CalcLabelScore(T* scores,
const T* input,
const int label_idx,
const int score_idx,
const int class_num,
const T conf,
const int stride) {
for (int i = 0; i < class_num; ++i) {
scores[score_idx + i] = conf * sigmoid<T>(input[label_idx + i * stride]);
}
}
template <typename T>
__global__ void KeYoloBoxFw(const T* input,
const T* imgsize,
T* boxes,
T* scores,
const float conf_thresh,
const int* anchors,
const int n,
const int h,
const int w,
const int an_num,
const int class_num,
const int box_num,
int input_size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
T box[4];
for (; tid < n * box_num; tid += stride) {
int grid_num = h * w;
int i = tid / box_num;
int j = (tid % box_num) / grid_num;
int k = (tid % grid_num) / w;
int l = tid % w;
int an_stride = (5 + class_num) * grid_num;
int img_height = static_cast<int>(imgsize[2 * i]);
int img_width = static_cast<int>(imgsize[2 * i + 1]);
int obj_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 4);
T conf = sigmoid<T>(input[obj_idx]);
if (conf < conf_thresh) {
continue;
}
int box_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0);
GetYoloBox<T>(box,
input,
anchors,
l,
k,
j,
h,
input_size,
box_idx,
grid_num,
img_height,
img_width);
box_idx = (i * box_num + j * grid_num + k * w + l) * 4;
CalcDetectionBox<T>(boxes, box, box_idx, img_height, img_width);
int label_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5);
int score_idx = (i * box_num + j * grid_num + k * w + l) * class_num;
CalcLabelScore<T>(
scores, input, label_idx, score_idx, class_num, conf, grid_num);
}
}
void YoloBoxCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
lite::Tensor* X = param.X;
lite::Tensor* ImgSize = param.ImgSize;
lite::Tensor* Boxes = param.Boxes;
lite::Tensor* Scores = param.Scores;
std::vector<int> anchors = param.anchors;
int class_num = param.class_num;
float conf_thresh = param.conf_thresh;
int downsample_ratio = param.downsample_ratio;
const float* input = X->data<float>();
const float* imgsize = ImgSize->data<float>();
float* boxes = Boxes->mutable_data<float>(TARGET(kCUDA));
float* scores = Scores->mutable_data<float>(TARGET(kCUDA));
const int n = X->dims()[0];
const int h = X->dims()[2];
const int w = X->dims()[3];
const int box_num = Boxes->dims()[1];
const int an_num = anchors.size() / 2;
int input_size = downsample_ratio * h;
anchors_.Resize(static_cast<int>({anchors.size()}));
int* d_anchors = anchors_.mutable_data<int>(TARGET(kCUDA));
CopySync<TARGET(kCUDA)>(d_anchors,
anchors.data(),
sizeof(int) * anchors.size(),
IoDirection::HtoD);
int threads = 512;
int blocks = (n * box_num + threads - 1) / threads;
blocks = blocks > 8 ? 8 : blocks;
KeYoloBoxFw<float><<<blocks, threads, 0, stream>>>(input,
imgsize,
boxes,
scores,
conf_thresh,
d_anchors,
n,
h,
w,
an_num,
class_num,
box_num,
input_size);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(yolo_box,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::YoloBoxCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("ImgSize", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Boxes", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Scores", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
...@@ -13,33 +13,25 @@ ...@@ -13,33 +13,25 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/kernel.h"
#include <algorithm>
#include <map>
#include <string>
#include <utility>
#include <vector>
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace arm { namespace kernels {
namespace math { namespace cuda {
class YoloBoxCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::YoloBoxParam;
void Run() override;
virtual ~YoloBoxCompute() = default;
template <typename dtype> private:
void multiclass_nms(const dtype* bbox_cpu_data, lite::Tensor anchors_;
const dtype* conf_cpu_data, };
std::vector<dtype>* result,
const std::vector<int>& priors,
int class_num,
int background_id,
int keep_topk,
int nms_topk,
float conf_thresh,
float nms_thresh,
float nms_eta,
bool share_location);
} // namespace math } // namespace cuda
} // namespace arm } // namespace kernels
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/yolo_box_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
inline static float sigmoid(float x) { return 1.f / (1.f + expf(-x)); }
inline static void get_yolo_box(float* box,
const float* x,
const int* anchors,
int i,
int j,
int an_idx,
int grid_size,
int input_size,
int index,
int stride,
int img_height,
int img_width) {
box[0] = (i + sigmoid(x[index])) * img_width / grid_size;
box[1] = (j + sigmoid(x[index + stride])) * img_height / grid_size;
box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width /
input_size;
box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] *
img_height / input_size;
}
inline static int get_entry_index(int batch,
int an_idx,
int hw_idx,
int an_num,
int an_stride,
int stride,
int entry) {
return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx;
}
inline static void calc_detection_box(float* boxes,
float* box,
const int box_idx,
const int img_height,
const int img_width) {
boxes[box_idx] = box[0] - box[2] / 2;
boxes[box_idx + 1] = box[1] - box[3] / 2;
boxes[box_idx + 2] = box[0] + box[2] / 2;
boxes[box_idx + 3] = box[1] + box[3] / 2;
boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : static_cast<float>(0);
boxes[box_idx + 1] =
boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast<float>(0);
boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1
? boxes[box_idx + 2]
: static_cast<float>(img_width - 1);
boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1
? boxes[box_idx + 3]
: static_cast<float>(img_height - 1);
}
inline static void calc_label_score(float* scores,
const float* input,
const int label_idx,
const int score_idx,
const int class_num,
const float conf,
const int stride) {
for (int i = 0; i < class_num; i++) {
scores[score_idx + i] = conf * sigmoid(input[label_idx + i * stride]);
}
}
template <typename T>
static void YoloBoxRef(const T* input,
const T* imgsize,
T* boxes,
T* scores,
const float conf_thresh,
const int* anchors,
const int n,
const int h,
const int w,
const int an_num,
const int class_num,
const int box_num,
int input_size) {
const int stride = h * w;
const int an_stride = (class_num + 5) * stride;
float box[4];
for (int i = 0; i < n; i++) {
int img_height = static_cast<int>(imgsize[2 * i]);
int img_width = static_cast<int>(imgsize[2 * i + 1]);
for (int j = 0; j < an_num; j++) {
for (int k = 0; k < h; k++) {
for (int l = 0; l < w; l++) {
int obj_idx =
get_entry_index(i, j, k * w + l, an_num, an_stride, stride, 4);
float conf = sigmoid(input[obj_idx]);
if (conf < conf_thresh) {
continue;
}
int box_idx =
get_entry_index(i, j, k * w + l, an_num, an_stride, stride, 0);
get_yolo_box(box,
input,
anchors,
l,
k,
j,
h,
input_size,
box_idx,
stride,
img_height,
img_width);
box_idx = (i * box_num + j * stride + k * w + l) * 4;
calc_detection_box(boxes, box, box_idx, img_height, img_width);
int label_idx =
get_entry_index(i, j, k * w + l, an_num, an_stride, stride, 5);
int score_idx = (i * box_num + j * stride + k * w + l) * class_num;
calc_label_score(
scores, input, label_idx, score_idx, class_num, conf, stride);
}
}
}
}
}
TEST(yolo_box, normal) {
YoloBoxCompute yolo_box_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::YoloBoxParam param;
lite::Tensor x, sz, x_cpu, sz_cpu;
lite::Tensor boxes, scores, boxes_cpu, scores_cpu;
lite::Tensor x_ref, sz_ref, boxes_ref, scores_ref;
int s = 3, cls = 4;
int n = 1, c = s * (5 + cls), h = 16, w = 16;
param.anchors = {2, 3, 4, 5, 8, 10};
param.downsample_ratio = 2;
param.conf_thresh = 0.5;
param.class_num = cls;
int m = h * w * param.anchors.size() / 2;
x.Resize({n, c, h, w});
sz.Resize({1, 2});
boxes.Resize({n, m, 4});
scores.Resize({n, cls, m});
x_cpu.Resize({n, c, h, w});
sz_cpu.Resize({1, 2});
boxes_cpu.Resize({n, m, 4});
scores_cpu.Resize({n, cls, m});
x_ref.Resize({n, c, h, w});
sz_ref.Resize({1, 2});
boxes_ref.Resize({n, m, 4});
scores_ref.Resize({n, cls, m});
auto* x_data = x.mutable_data<float>(TARGET(kCUDA));
auto* sz_data = sz.mutable_data<float>(TARGET(kCUDA));
auto* boxes_data = boxes.mutable_data<float>(TARGET(kCUDA));
auto* scores_data = scores.mutable_data<float>(TARGET(kCUDA));
float* x_cpu_data = x_cpu.mutable_data<float>();
float* sz_cpu_data = sz_cpu.mutable_data<float>();
float* boxes_cpu_data = boxes_cpu.mutable_data<float>();
float* scores_cpu_data = scores_cpu.mutable_data<float>();
float* x_ref_data = x_ref.mutable_data<float>();
float* sz_ref_data = sz_ref.mutable_data<float>();
float* boxes_ref_data = boxes_ref.mutable_data<float>();
float* scores_ref_data = scores_ref.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); ++i) {
x_cpu_data[i] = i - 5.0;
x_ref_data[i] = i - 5.0;
}
sz_cpu_data[0] = 16;
sz_cpu_data[1] = 32;
sz_ref_data[0] = 16;
sz_ref_data[1] = 32;
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
sz.Assign<float, lite::DDim, TARGET(kCUDA)>(sz_cpu_data, sz_cpu.dims());
param.X = &x;
param.ImgSize = &sz;
param.Boxes = &boxes;
param.Scores = &scores;
yolo_box_kernel.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
yolo_box_kernel.SetContext(std::move(ctx));
yolo_box_kernel.Launch();
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(boxes_cpu_data,
boxes_data,
sizeof(float) * boxes.numel(),
IoDirection::DtoH);
CopySync<TARGET(kCUDA)>(scores_cpu_data,
scores_data,
sizeof(float) * scores.numel(),
IoDirection::DtoH);
YoloBoxRef<float>(x_ref_data,
sz_ref_data,
boxes_ref_data,
scores_ref_data,
param.conf_thresh,
param.anchors.data(),
n,
h,
w,
param.anchors.size() / 2,
cls,
m,
param.downsample_ratio * h);
for (int i = 0; i < boxes.numel(); i++) {
EXPECT_NEAR(boxes_cpu_data[i], boxes_ref_data[i], 1e-5);
}
for (int i = 0; i < scores.numel(); i++) {
EXPECT_NEAR(scores_cpu_data[i], scores_ref_data[i], 1e-5);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
...@@ -3,5 +3,7 @@ message(STATUS "compile with lite host kernels") ...@@ -3,5 +3,7 @@ message(STATUS "compile with lite host kernels")
add_kernel(feed_compute_host Host basic SRCS feed_compute.cc DEPS ${lite_kernel_deps}) add_kernel(feed_compute_host Host basic SRCS feed_compute.cc DEPS ${lite_kernel_deps})
add_kernel(fetch_compute_host Host basic SRCS fetch_compute.cc DEPS ${lite_kernel_deps}) add_kernel(fetch_compute_host Host basic SRCS fetch_compute.cc DEPS ${lite_kernel_deps})
add_kernel(reshape_compute_host Host basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op) add_kernel(reshape_compute_host Host basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op)
add_kernel(multiclass_nms_compute_host Host basic SRCS multiclass_nms_compute.cc DEPS ${lite_kernel_deps})
lite_cc_test(test_reshape_compute_host SRCS reshape_compute_test.cc DEPS reshape_compute_host any) lite_cc_test(test_reshape_compute_host SRCS reshape_compute_test.cc DEPS reshape_compute_host any)
lite_cc_test(test_multiclass_nms_compute_host SRCS multiclass_nms_compute_test.cc DEPS multiclass_nms_compute_host any)
...@@ -12,13 +12,15 @@ ...@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/arm/math/multiclass_nms.h" #include "lite/kernels/host/multiclass_nms_compute.h"
#include "lite/arm/math/funcs.h" #include <map>
#include <utility>
#include <vector>
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace arm { namespace kernels {
namespace math { namespace host {
template <typename dtype> template <typename dtype>
static bool sort_score_pair_descend(const std::pair<float, dtype>& pair1, static bool sort_score_pair_descend(const std::pair<float, dtype>& pair1,
...@@ -269,31 +271,92 @@ void multiclass_nms(const dtype* bbox_cpu_data, ...@@ -269,31 +271,92 @@ void multiclass_nms(const dtype* bbox_cpu_data,
} }
} }
template float jaccard_overlap(const float* bbox1, const float* bbox2); void MulticlassNmsCompute::Run() {
auto& param = Param<operators::MulticlassNmsParam>();
template void apply_nms_fast(const float* bboxes, // bbox shape : N, M, 4
const float* scores, // scores shape : N, C, M
int num, const float* bbox_data = param.bbox_data->data<float>();
float score_threshold, const float* conf_data = param.conf_data->data<float>();
float nms_threshold,
float eta, CHECK_EQ(param.bbox_data->dims().production() % 4, 0);
int top_k,
std::vector<int>* indices); std::vector<float> result;
int N = param.bbox_data->dims()[0];
template void multiclass_nms(const float* bbox_cpu_data, int M = param.bbox_data->dims()[1];
const float* conf_cpu_data, std::vector<int> priors(N, M);
std::vector<float>* result, int class_num = param.conf_data->dims()[1];
const std::vector<int>& priors, int background_label = param.background_label;
int class_num, int keep_top_k = param.keep_top_k;
int background_id, int nms_top_k = param.nms_top_k;
int keep_topk, float score_threshold = param.score_threshold;
int nms_topk, float nms_threshold = param.nms_threshold;
float conf_thresh, float nms_eta = param.nms_eta;
float nms_thresh, bool share_location = param.share_location;
float nms_eta,
bool share_location); multiclass_nms(bbox_data,
conf_data,
} // namespace math &result,
} // namespace arm priors,
class_num,
background_label,
keep_top_k,
nms_top_k,
score_threshold,
nms_threshold,
nms_eta,
share_location);
lite::LoD lod;
std::vector<uint64_t> lod_info;
lod_info.push_back(0);
std::vector<float> result_corrected;
int tmp_batch_id;
uint64_t num = 0;
for (int i = 0; i < result.size(); ++i) {
if (i == 0) {
tmp_batch_id = result[i];
}
if (i % 7 == 0) {
if (result[i] == tmp_batch_id) {
++num;
} else {
lod_info.push_back(num);
++num;
tmp_batch_id = result[i];
}
} else {
result_corrected.push_back(result[i]);
}
}
lod_info.push_back(num);
lod.push_back(lod_info);
if (result_corrected.empty()) {
lod.clear();
lod.push_back(std::vector<uint64_t>({0, 1}));
param.out->Resize({static_cast<int64_t>(1)});
param.out->mutable_data<float>()[0] = -1.;
param.out->set_lod(lod);
} else {
param.out->Resize({static_cast<int64_t>(result_corrected.size() / 6), 6});
float* out = param.out->mutable_data<float>();
std::memcpy(
out, result_corrected.data(), sizeof(float) * result_corrected.size());
param.out->set_lod(lod);
}
}
} // namespace host
} // namespace kernels
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL(multiclass_nms,
kHost,
kFloat,
kNCHW,
paddle::lite::kernels::host::MulticlassNmsCompute,
def)
.BindInput("BBoxes", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("Scores", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize();
...@@ -13,26 +13,24 @@ ...@@ -13,26 +13,24 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <vector> #include <algorithm>
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace host {
class MulticlassNmsCompute class MulticlassNmsCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> { : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
public: public:
using param_t = operators::MulticlassNmsParam;
void Run() override; void Run() override;
virtual ~MulticlassNmsCompute() = default; virtual ~MulticlassNmsCompute() = default;
}; };
} // namespace arm } // namespace host
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -12,19 +12,16 @@ ...@@ -12,19 +12,16 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/arm/multiclass_nms_compute.h" #include "lite/kernels/host/multiclass_nms_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <algorithm>
#include <map> #include <map>
#include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace host {
template <typename dtype> template <typename dtype>
static bool sort_score_pair_descend(const std::pair<float, dtype>& pair1, static bool sort_score_pair_descend(const std::pair<float, dtype>& pair1,
...@@ -279,21 +276,21 @@ void multiclass_nms_compute_ref(const operators::MulticlassNmsParam& param, ...@@ -279,21 +276,21 @@ void multiclass_nms_compute_ref(const operators::MulticlassNmsParam& param,
} }
} }
TEST(multiclass_nms_arm, retrive_op) { TEST(multiclass_nms_host, init) {
MulticlassNmsCompute multiclass_nms;
ASSERT_EQ(multiclass_nms.precision(), PRECISION(kFloat));
ASSERT_EQ(multiclass_nms.target(), TARGET(kHost));
}
TEST(multiclass_nms_host, retrive_op) {
auto multiclass_nms = auto multiclass_nms =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>( KernelRegistry::Global().Create<TARGET(kHost), PRECISION(kFloat)>(
"multiclass_nms"); "multiclass_nms");
ASSERT_FALSE(multiclass_nms.empty()); ASSERT_FALSE(multiclass_nms.empty());
ASSERT_TRUE(multiclass_nms.front()); ASSERT_TRUE(multiclass_nms.front());
} }
TEST(multiclass_nms_arm, init) { TEST(multiclass_nms_host, compute) {
MulticlassNmsCompute multiclass_nms;
ASSERT_EQ(multiclass_nms.precision(), PRECISION(kFloat));
ASSERT_EQ(multiclass_nms.target(), TARGET(kARM));
}
TEST(multiclass_nms_arm, compute) {
MulticlassNmsCompute multiclass_nms; MulticlassNmsCompute multiclass_nms;
operators::MulticlassNmsParam param; operators::MulticlassNmsParam param;
lite::Tensor bbox, conf, out; lite::Tensor bbox, conf, out;
...@@ -306,9 +303,6 @@ TEST(multiclass_nms_arm, compute) { ...@@ -306,9 +303,6 @@ TEST(multiclass_nms_arm, compute) {
DDim* bbox_dim; DDim* bbox_dim;
DDim* conf_dim; DDim* conf_dim;
int M = priors[0]; int M = priors[0];
// for (int i = 0; i < priors.size(); ++i) {
// M += priors[i];
//}
if (share_location) { if (share_location) {
bbox_dim = new DDim({N, M, 4}); bbox_dim = new DDim({N, M, 4});
} else { } else {
...@@ -368,9 +362,9 @@ TEST(multiclass_nms_arm, compute) { ...@@ -368,9 +362,9 @@ TEST(multiclass_nms_arm, compute) {
} }
} }
} // namespace arm } // namespace host
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
USE_LITE_KERNEL(multiclass_nms, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(multiclass_nms, kHost, kFloat, kNCHW, def);
...@@ -33,7 +33,6 @@ endif() ...@@ -33,7 +33,6 @@ endif()
lite_cc_test(test_sgemm SRCS test_sgemm.cc DEPS ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_sgemm SRCS test_sgemm.cc DEPS ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_pad2d_compute SRCS pad2d_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_pad2d_compute SRCS pad2d_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_prior_box_compute SRCS prior_box_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_prior_box_compute SRCS prior_box_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_multiclass_nms_compute SRCS multiclass_nms_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_negative_compute SRCS negative_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_negative_compute SRCS negative_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_bilinear_interp_compute SRCS bilinear_interp_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_bilinear_interp_compute SRCS bilinear_interp_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_nearest_interp_compute SRCS nearest_interp_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_nearest_interp_compute SRCS nearest_interp_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
bool read_file(std::vector<float>* result, const char* file_name) {
std::ifstream infile(file_name);
if (!infile.good()) {
std::cout << "Cannot open " << file_name << std::endl;
return false;
}
LOG(INFO) << "found filename: " << file_name;
std::string line;
while (std::getline(infile, line)) {
(*result).push_back(static_cast<float>(atof(line.c_str())));
}
return true;
}
const char* bboxes_file = "multiclass_nms_bboxes_file.txt";
const char* scores_file = "multiclass_nms_scores_file.txt";
const char* out_file = "multiclass_nms_out_file.txt";
namespace paddle {
namespace lite {
class MulticlassNmsComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string bbox_ = "BBoxes";
std::string conf_ = "Scores";
std::string out_ = "Out";
std::vector<int> priors_;
int class_num_;
int background_id_;
int keep_topk_;
int nms_topk_;
float conf_thresh_;
float nms_thresh_;
float nms_eta_;
bool share_location_;
DDim bbox_dims_;
DDim conf_dims_;
public:
MulticlassNmsComputeTester(const Place& place,
const std::string& alias,
std::vector<int> priors,
int class_num,
int background_id,
int keep_topk,
int nms_topk,
float conf_thresh,
float nms_thresh,
float nms_eta,
bool share_location,
DDim bbox_dims,
DDim conf_dims)
: TestCase(place, alias),
priors_(priors),
class_num_(class_num),
background_id_(background_id),
keep_topk_(keep_topk),
nms_topk_(nms_topk),
conf_thresh_(conf_thresh),
nms_thresh_(nms_thresh),
nms_eta_(nms_eta),
share_location_(share_location),
bbox_dims_(bbox_dims),
conf_dims_(conf_dims) {}
void RunBaseline(Scope* scope) override {
std::vector<float> vbbox;
std::vector<float> vscore;
std::vector<float> vout;
if (!read_file(&vout, out_file)) {
LOG(ERROR) << "load ground truth failed";
return;
}
auto* out = scope->NewTensor(out_);
CHECK(out);
out->Resize(DDim({static_cast<int64_t>(vout.size() / 6), 6}));
auto* out_data = out->mutable_data<float>();
memcpy(out_data, vout.data(), vout.size() * sizeof(float));
out->mutable_lod()->push_back(std::vector<uint64_t>({0, 10}));
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("multiclass_nms");
op_desc->SetInput("BBoxes", {bbox_});
op_desc->SetInput("Scores", {conf_});
op_desc->SetOutput("Out", {out_});
op_desc->SetAttr("background_label", background_id_);
op_desc->SetAttr("keep_top_k", keep_topk_);
op_desc->SetAttr("nms_top_k", nms_topk_);
op_desc->SetAttr("score_threshold", conf_thresh_);
op_desc->SetAttr("nms_threshold", nms_thresh_);
op_desc->SetAttr("nms_eta", nms_eta_);
op_desc->SetAttr("share_location", share_location_);
}
void PrepareData() override {
std::vector<float> bbox_data;
std::vector<float> conf_data;
if (!read_file(&bbox_data, bboxes_file)) {
LOG(ERROR) << "load bbox file failed";
return;
}
if (!read_file(&conf_data, scores_file)) {
LOG(ERROR) << "load score file failed";
return;
}
SetCommonTensor(bbox_, bbox_dims_, bbox_data.data());
SetCommonTensor(conf_, conf_dims_, conf_data.data());
}
};
void test_multiclass_nms(Place place) {
int keep_top_k = 200;
int nms_top_k = 400;
float nms_eta = 1.;
float score_threshold = 0.009999999776482582;
int background_label = 0;
float nms_threshold = 0.44999998807907104;
int N = 1;
int M = 1917;
int class_num = 21;
bool share_location = true;
std::vector<int> priors(N, M);
std::unique_ptr<arena::TestCase> tester(
new MulticlassNmsComputeTester(place,
"def",
priors,
class_num,
background_label,
keep_top_k,
nms_top_k,
score_threshold,
nms_threshold,
nms_eta,
share_location,
DDim({N, M, 4}),
DDim({N, class_num, M})));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
TEST(MulticlassNms, precision) {
#ifdef LITE_WITH_X86
Place place(TARGET(kX86));
#endif
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
test_multiclass_nms(place);
#endif
}
} // namespace lite
} // namespace paddle
...@@ -524,22 +524,6 @@ function build_npu { ...@@ -524,22 +524,6 @@ function build_npu {
fi fi
} }
function __prepare_multiclass_nms_test_files {
local port=$1
local adb_work_dir="/data/local/tmp"
wget --no-check-certificate https://raw.githubusercontent.com/jiweibo/TestData/master/multiclass_nms_bboxes_file.txt \
-O lite/tests/kernels/multiclass_nms_bboxes_file.txt
wget --no-check-certificate https://raw.githubusercontent.com/jiweibo/TestData/master/multiclass_nms_scores_file.txt \
-O lite/tests/kernels/multiclass_nms_scores_file.txt
wget --no-check-certificate https://raw.githubusercontent.com/jiweibo/TestData/master/multiclass_nms_out_file.txt \
-O lite/tests/kernels/multiclass_nms_out_file.txt
adb -s emulator-${port} push lite/tests/kernels/multiclass_nms_bboxes_file.txt ${adb_work_dir}
adb -s emulator-${port} push lite/tests/kernels/multiclass_nms_scores_file.txt ${adb_work_dir}
adb -s emulator-${port} push lite/tests/kernels/multiclass_nms_out_file.txt ${adb_work_dir}
}
# $1: ARM_TARGET_OS in "android" , "armlinux" # $1: ARM_TARGET_OS in "android" , "armlinux"
# $2: ARM_TARGET_ARCH_ABI in "armv8", "armv7" ,"armv7hf" # $2: ARM_TARGET_ARCH_ABI in "armv8", "armv7" ,"armv7hf"
# $3: ARM_TARGET_LANG in "gcc" "clang" # $3: ARM_TARGET_LANG in "gcc" "clang"
...@@ -562,9 +546,6 @@ function test_arm { ...@@ -562,9 +546,6 @@ function test_arm {
return 0 return 0
fi fi
echo "prepare multiclass_nms_test files..."
__prepare_multiclass_nms_test_files $port
# prepare for CXXApi test # prepare for CXXApi test
local adb="adb -s emulator-${port}" local adb="adb -s emulator-${port}"
$adb shell mkdir -p /data/local/tmp/lite_naive_model_opt $adb shell mkdir -p /data/local/tmp/lite_naive_model_opt
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册