diff --git a/lite/api/_paddle_use_kernels.h b/lite/api/_paddle_use_kernels.h index fe522592370c778aacb6fc4b7c4668ef80e62728..d54caa83e1226e838bd2aca78ae7218f9c2fb664 100644 --- a/lite/api/_paddle_use_kernels.h +++ b/lite/api/_paddle_use_kernels.h @@ -31,6 +31,7 @@ USE_LITE_KERNEL(fetch, kFPGA, kFP16, kNHWC, def); // host kernels USE_LITE_KERNEL(reshape, kHost, kAny, kAny, def); USE_LITE_KERNEL(reshape2, kHost, kAny, kAny, def); +USE_LITE_KERNEL(multiclass_nms, kHost, kFloat, kNCHW, def); #ifdef LITE_WITH_ARM USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def); @@ -92,7 +93,6 @@ USE_LITE_KERNEL(top_k, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(increment, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(write_to_array, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(read_from_array, kARM, kFloat, kNCHW, def); -USE_LITE_KERNEL(multiclass_nms, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(reduce_max, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(sequence_expand, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(sequence_pool, kARM, kFloat, kNCHW, def); @@ -155,6 +155,7 @@ USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host); USE_LITE_KERNEL(io_copy_once, kCUDA, kAny, kAny, host_to_device); USE_LITE_KERNEL(io_copy_once, kCUDA, kAny, kAny, device_to_host); USE_LITE_KERNEL(leaky_relu, kCUDA, kFloat, kNCHW, def); +USE_LITE_KERNEL(yolo_box, kCUDA, kFloat, kNCHW, def); #endif #ifdef LITE_WITH_OPENCL diff --git a/lite/arm/math/CMakeLists.txt b/lite/arm/math/CMakeLists.txt index e228259da39f8f100dab0bc1d97732a03d88b431..f17928cc2935089fd4d9925d9791f0190f1e5c85 100644 --- a/lite/arm/math/CMakeLists.txt +++ b/lite/arm/math/CMakeLists.txt @@ -56,7 +56,6 @@ if (NOT HAS_ARM_MATH_LIB_DIR) elementwise.cc lrn.cc decode_bboxes.cc - multiclass_nms.cc concat.cc sgemv.cc type_trans.cc diff --git a/lite/arm/math/funcs.h b/lite/arm/math/funcs.h index 9b4a1ca9726eeab1e97088d4a930e6d7fbf07b52..18a9d950411ad99124e4ec323cd88bb5dff391fa 100644 --- a/lite/arm/math/funcs.h +++ b/lite/arm/math/funcs.h @@ -39,7 +39,6 @@ #include "lite/arm/math/increment.h" #include "lite/arm/math/interpolate.h" #include "lite/arm/math/lrn.h" -#include "lite/arm/math/multiclass_nms.h" #include "lite/arm/math/negative.h" #include "lite/arm/math/norm.h" #include "lite/arm/math/packed_sgemm.h" diff --git a/lite/core/mir/fusion/shuffle_channel_fuser.cc b/lite/core/mir/fusion/shuffle_channel_fuser.cc index 01a091577a458f35566f4ae5b73f9fc4dd18e89d..f0087f8991b6b4457da29db0feac30c6bf9e722e 100644 --- a/lite/core/mir/fusion/shuffle_channel_fuser.cc +++ b/lite/core/mir/fusion/shuffle_channel_fuser.cc @@ -28,12 +28,17 @@ void ShuffleChannelFuser::BuildPattern() { auto* y2 = VarNode("y2")->assert_is_op_output(transpose_type_, "Out"); auto* out = VarNode("out")->assert_is_op_output(reshape_type_, "Out"); - auto* xshape1 = - VarNode("xshape1")->assert_is_op_output(reshape_type_, "XShape"); - auto* xshape2 = - VarNode("xshape2")->assert_is_op_output(transpose_type_, "XShape"); - auto* xshape3 = - VarNode("xshape3")->assert_is_op_output(reshape_type_, "XShape"); + PMNode* xshape1 = nullptr; + PMNode* xshape2 = nullptr; + PMNode* xshape3 = nullptr; + if (reshape_type_ == "reshape2") { + xshape1 = VarNode("xshape1")->assert_is_op_output(reshape_type_, "XShape"); + xshape3 = VarNode("xshape3")->assert_is_op_output(reshape_type_, "XShape"); + } + if (transpose_type_ == "transpose2") { + xshape2 = + VarNode("xshape2")->assert_is_op_output(transpose_type_, "XShape"); + } auto* reshape1 = OpNode("reshape1", reshape_type_) ->assert_op_attr_satisfied>( @@ -54,16 +59,16 @@ void ShuffleChannelFuser::BuildPattern() { // create topology. *x1 >> *reshape1 >> *y1 >> *transpose >> *y2 >> *reshape2 >> *out; - *reshape1 >> *xshape1; - *transpose >> *xshape2; - *reshape2 >> *xshape3; + if (xshape1) *reshape1 >> *xshape1; + if (xshape2) *transpose >> *xshape2; + if (xshape3) *reshape2 >> *xshape3; // Some op specialities. y1->AsIntermediate(); y2->AsIntermediate(); - xshape1->AsIntermediate(); - xshape2->AsIntermediate(); - xshape3->AsIntermediate(); + if (xshape1) xshape1->AsIntermediate(); + if (xshape2) xshape2->AsIntermediate(); + if (xshape3) xshape3->AsIntermediate(); reshape1->AsIntermediate(); transpose->AsIntermediate(); reshape2->AsIntermediate(); diff --git a/lite/core/mir/fusion/transpose_softmax_transpose_fuser.cc b/lite/core/mir/fusion/transpose_softmax_transpose_fuser.cc index 5e559994428dda67c05f7d14b31de5418826b449..d578b725ec42c926e5f0581fd8eeef855e586bdc 100644 --- a/lite/core/mir/fusion/transpose_softmax_transpose_fuser.cc +++ b/lite/core/mir/fusion/transpose_softmax_transpose_fuser.cc @@ -28,10 +28,14 @@ void TransposeSoftmaxTransposeFuser::BuildPattern() { auto* y2 = VarNode("y2")->assert_is_op_output(softmax_type_, "Out"); auto* out = VarNode("out")->assert_is_op_output(transpose_type_, "Out"); - auto* xshape1 = - VarNode("xshape1")->assert_is_op_output(transpose_type_, "XShape"); - auto* xshape2 = - VarNode("xshape2")->assert_is_op_output(transpose_type_, "XShape"); + PMNode* xshape1 = nullptr; + PMNode* xshape2 = nullptr; + if (transpose_type_ == "transpose2") { + xshape1 = + VarNode("xshape1")->assert_is_op_output(transpose_type_, "XShape"); + xshape2 = + VarNode("xshape2")->assert_is_op_output(transpose_type_, "XShape"); + } auto* transpose1 = OpNode("transpose1", transpose_type_)->assert_is_op(transpose_type_); @@ -45,14 +49,14 @@ void TransposeSoftmaxTransposeFuser::BuildPattern() { // create topology. *x1 >> *transpose1 >> *y1 >> *softmax >> *y2 >> *transpose2 >> *out; - *transpose1 >> *xshape1; - *transpose2 >> *xshape2; + if (xshape1) *transpose1 >> *xshape1; + if (xshape2) *transpose2 >> *xshape2; // nodes to remove y1->AsIntermediate(); y2->AsIntermediate(); - xshape1->AsIntermediate(); - xshape2->AsIntermediate(); + if (xshape1) xshape1->AsIntermediate(); + if (xshape2) xshape2->AsIntermediate(); transpose1->AsIntermediate(); softmax->AsIntermediate(); transpose2->AsIntermediate(); diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index 19d3f3fd762267c1ef21e639ce62824531f85cef..9b6997f1efd549836c550c40b5a2f196b352cfa7 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -15,7 +15,6 @@ add_kernel(batch_norm_compute_arm ARM basic SRCS batch_norm_compute.cc DEPS ${li add_kernel(elementwise_compute_arm ARM basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(lrn_compute_arm ARM basic SRCS lrn_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(decode_bboxes_compute_arm ARM basic SRCS decode_bboxes_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(multiclass_nms_compute_arm ARM basic SRCS multiclass_nms_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(pool_compute_arm ARM basic SRCS pool_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(split_compute_arm ARM basic SRCS split_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(concat_compute_arm ARM basic SRCS concat_compute.cc DEPS ${lite_kernel_deps} math_arm) @@ -80,7 +79,6 @@ lite_cc_test(test_batch_norm_compute_arm SRCS batch_norm_compute_test.cc DEPS ba lite_cc_test(test_elementwise_compute_arm SRCS elementwise_compute_test.cc DEPS elementwise_compute_arm) lite_cc_test(test_lrn_compute_arm SRCS lrn_compute_test.cc DEPS lrn_compute_arm) lite_cc_test(test_decode_bboxes_compute_arm SRCS decode_bboxes_compute_test.cc DEPS decode_bboxes_compute_arm) -lite_cc_test(test_multiclass_nms_compute_arm SRCS multiclass_nms_compute_test.cc DEPS multiclass_nms_compute_arm) lite_cc_test(test_pool_compute_arm SRCS pool_compute_test.cc DEPS pool_compute_arm) lite_cc_test(test_mul_compute_arm SRCS mul_compute_test.cc DEPS mul_compute_arm) lite_cc_test(test_split_compute_arm SRCS split_compute_test.cc DEPS split_compute_arm) diff --git a/lite/kernels/arm/multiclass_nms_compute.cc b/lite/kernels/arm/multiclass_nms_compute.cc deleted file mode 100644 index c36a81d1522c77be8535f8c53cf0657a95f765fa..0000000000000000000000000000000000000000 --- a/lite/kernels/arm/multiclass_nms_compute.cc +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "lite/kernels/arm/multiclass_nms_compute.h" -#include -#include "lite/arm/math/funcs.h" - -namespace paddle { -namespace lite { -namespace kernels { -namespace arm { - -void MulticlassNmsCompute::Run() { - auto& param = Param(); - // bbox shape : N, M, 4 - // scores shape : N, C, M - const float* bbox_data = param.bbox_data->data(); - const float* conf_data = param.conf_data->data(); - - CHECK_EQ(param.bbox_data->dims().production() % 4, 0); - - std::vector result; - int N = param.bbox_data->dims()[0]; - int M = param.bbox_data->dims()[1]; - std::vector priors(N, M); - int class_num = param.conf_data->dims()[1]; - int background_label = param.background_label; - int keep_top_k = param.keep_top_k; - int nms_top_k = param.nms_top_k; - float score_threshold = param.score_threshold; - float nms_threshold = param.nms_threshold; - float nms_eta = param.nms_eta; - bool share_location = param.share_location; - - lite::arm::math::multiclass_nms(bbox_data, - conf_data, - &result, - priors, - class_num, - background_label, - keep_top_k, - nms_top_k, - score_threshold, - nms_threshold, - nms_eta, - share_location); - lite::LoD* lod = param.out->mutable_lod(); - std::vector lod_info; - lod_info.push_back(0); - std::vector result_corrected; - int tmp_batch_id; - uint64_t num = 0; - for (int i = 0; i < result.size(); ++i) { - if (i == 0) { - tmp_batch_id = result[i]; - } - if (i % 7 == 0) { - if (result[i] == tmp_batch_id) { - ++num; - } else { - lod_info.push_back(num); - ++num; - tmp_batch_id = result[i]; - } - } else { - result_corrected.push_back(result[i]); - } - } - lod_info.push_back(num); - (*lod).push_back(lod_info); - - if (result_corrected.empty()) { - (*lod).clear(); - (*lod).push_back(std::vector({0, 1})); - param.out->Resize({static_cast(1)}); - param.out->mutable_data()[0] = -1.; - } else { - param.out->Resize({static_cast(result_corrected.size() / 6), 6}); - float* out = param.out->mutable_data(); - std::memcpy( - out, result_corrected.data(), sizeof(float) * result_corrected.size()); - } -} - -} // namespace arm -} // namespace kernels -} // namespace lite -} // namespace paddle - -REGISTER_LITE_KERNEL(multiclass_nms, - kARM, - kFloat, - kNCHW, - paddle::lite::kernels::arm::MulticlassNmsCompute, - def) - .BindInput("BBoxes", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindInput("Scores", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) - .Finalize(); diff --git a/lite/kernels/arm/scale_compute.cc b/lite/kernels/arm/scale_compute.cc index 42a5575cc38358e3967ab1fee46b62cd1550f648..07287111a13abd7e50fcc814ae034e4cccb98f55 100644 --- a/lite/kernels/arm/scale_compute.cc +++ b/lite/kernels/arm/scale_compute.cc @@ -32,6 +32,9 @@ void ScaleCompute::Run() { bias *= scale; } lite::arm::math::scale(x_data, output_data, x_dims.production(), scale, bias); + if (!param.x->lod().empty()) { + param.output->set_lod(param.x->lod()); + } } } // namespace arm diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt index a9a85bab911dbcb9bf8f1ee356e5aad60fdd47d0..1a198c1dd592ff1a44437144295ab2a7734d6a6b 100644 --- a/lite/kernels/cuda/CMakeLists.txt +++ b/lite/kernels/cuda/CMakeLists.txt @@ -7,13 +7,15 @@ message(STATUS "compile with lite CUDA kernels") nv_library(mul_compute_cuda SRCS mul_compute.cc DEPS ${lite_kernel_deps} context) lite_cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS ${lite_kernel_deps}) nv_library(leaky_relu_compute_cuda SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps}) - lite_cc_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_relu_compute_cuda) +nv_library(yolo_box_compute_cuda SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps}) +lite_cc_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda) set(cuda_kernels mul_compute_cuda io_copy_compute_cuda leaky_relu_compute_cuda +yolo_box_compute_cuda ) set(cuda_kernels "${cuda_kernels}" CACHE GLOBAL "cuda kernels") diff --git a/lite/kernels/cuda/yolo_box_compute.cu b/lite/kernels/cuda/yolo_box_compute.cu new file mode 100644 index 0000000000000000000000000000000000000000..99fff9a709338b081cbeed484bebbd694e383617 --- /dev/null +++ b/lite/kernels/cuda/yolo_box_compute.cu @@ -0,0 +1,225 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "lite/core/op_registry.h" +#include "lite/kernels/cuda/yolo_box_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +__host__ __device__ inline int GetEntryIndex(int batch, + int an_idx, + int hw_idx, + int an_num, + int an_stride, + int stride, + int entry) { + return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx; +} + +template +__host__ __device__ inline T sigmoid(T x) { + return 1.0 / (1.0 + std::exp(-x)); +} + +template +__host__ __device__ inline void GetYoloBox(T* box, + const T* x, + const int* anchors, + int i, + int j, + int an_idx, + int grid_size, + int input_size, + int index, + int stride, + int img_height, + int img_width) { + box[0] = (i + sigmoid(x[index])) * img_width / grid_size; + box[1] = (j + sigmoid(x[index + stride])) * img_height / grid_size; + box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width / + input_size; + box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] * + img_height / input_size; +} + +template +__host__ __device__ inline void CalcDetectionBox(T* boxes, + T* box, + const int box_idx, + const int img_height, + const int img_width) { + boxes[box_idx] = box[0] - box[2] / 2; + boxes[box_idx + 1] = box[1] - box[3] / 2; + boxes[box_idx + 2] = box[0] + box[2] / 2; + boxes[box_idx + 3] = box[1] + box[3] / 2; + + boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : static_cast(0); + boxes[box_idx + 1] = + boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast(0); + boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1 + ? boxes[box_idx + 2] + : static_cast(img_width - 1); + boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1 + ? boxes[box_idx + 3] + : static_cast(img_height - 1); +} + +template +__host__ __device__ inline void CalcLabelScore(T* scores, + const T* input, + const int label_idx, + const int score_idx, + const int class_num, + const T conf, + const int stride) { + for (int i = 0; i < class_num; ++i) { + scores[score_idx + i] = conf * sigmoid(input[label_idx + i * stride]); + } +} + +template +__global__ void KeYoloBoxFw(const T* input, + const T* imgsize, + T* boxes, + T* scores, + const float conf_thresh, + const int* anchors, + const int n, + const int h, + const int w, + const int an_num, + const int class_num, + const int box_num, + int input_size) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + T box[4]; + for (; tid < n * box_num; tid += stride) { + int grid_num = h * w; + int i = tid / box_num; + int j = (tid % box_num) / grid_num; + int k = (tid % grid_num) / w; + int l = tid % w; + + int an_stride = (5 + class_num) * grid_num; + int img_height = static_cast(imgsize[2 * i]); + int img_width = static_cast(imgsize[2 * i + 1]); + + int obj_idx = + GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 4); + T conf = sigmoid(input[obj_idx]); + if (conf < conf_thresh) { + continue; + } + + int box_idx = + GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0); + GetYoloBox(box, + input, + anchors, + l, + k, + j, + h, + input_size, + box_idx, + grid_num, + img_height, + img_width); + box_idx = (i * box_num + j * grid_num + k * w + l) * 4; + CalcDetectionBox(boxes, box, box_idx, img_height, img_width); + + int label_idx = + GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5); + int score_idx = (i * box_num + j * grid_num + k * w + l) * class_num; + CalcLabelScore( + scores, input, label_idx, score_idx, class_num, conf, grid_num); + } +} + +void YoloBoxCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto stream = ctx.exec_stream(); + + lite::Tensor* X = param.X; + lite::Tensor* ImgSize = param.ImgSize; + lite::Tensor* Boxes = param.Boxes; + lite::Tensor* Scores = param.Scores; + std::vector anchors = param.anchors; + int class_num = param.class_num; + float conf_thresh = param.conf_thresh; + int downsample_ratio = param.downsample_ratio; + + const float* input = X->data(); + const float* imgsize = ImgSize->data(); + float* boxes = Boxes->mutable_data(TARGET(kCUDA)); + float* scores = Scores->mutable_data(TARGET(kCUDA)); + + const int n = X->dims()[0]; + const int h = X->dims()[2]; + const int w = X->dims()[3]; + const int box_num = Boxes->dims()[1]; + const int an_num = anchors.size() / 2; + int input_size = downsample_ratio * h; + + anchors_.Resize(static_cast({anchors.size()})); + int* d_anchors = anchors_.mutable_data(TARGET(kCUDA)); + CopySync(d_anchors, + anchors.data(), + sizeof(int) * anchors.size(), + IoDirection::HtoD); + + int threads = 512; + int blocks = (n * box_num + threads - 1) / threads; + blocks = blocks > 8 ? 8 : blocks; + + KeYoloBoxFw<<>>(input, + imgsize, + boxes, + scores, + conf_thresh, + d_anchors, + n, + h, + w, + an_num, + class_num, + box_num, + input_size); + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(yolo_box, + kCUDA, + kFloat, + kNCHW, + paddle::lite::kernels::cuda::YoloBoxCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindInput("ImgSize", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Boxes", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Scores", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); diff --git a/lite/arm/math/multiclass_nms.h b/lite/kernels/cuda/yolo_box_compute.h similarity index 51% rename from lite/arm/math/multiclass_nms.h rename to lite/kernels/cuda/yolo_box_compute.h index a5f39b64620290290b7cb26c013e42261588f8e4..c8ea9d25b689fc38f440ca80ba02f91c155e9900 100644 --- a/lite/arm/math/multiclass_nms.h +++ b/lite/kernels/cuda/yolo_box_compute.h @@ -13,33 +13,25 @@ // limitations under the License. #pragma once - -#include -#include -#include -#include -#include +#include "lite/core/kernel.h" namespace paddle { namespace lite { -namespace arm { -namespace math { +namespace kernels { +namespace cuda { + +class YoloBoxCompute : public KernelLite { + public: + using param_t = operators::YoloBoxParam; + + void Run() override; + virtual ~YoloBoxCompute() = default; -template -void multiclass_nms(const dtype* bbox_cpu_data, - const dtype* conf_cpu_data, - std::vector* result, - const std::vector& priors, - int class_num, - int background_id, - int keep_topk, - int nms_topk, - float conf_thresh, - float nms_thresh, - float nms_eta, - bool share_location); + private: + lite::Tensor anchors_; +}; -} // namespace math -} // namespace arm +} // namespace cuda +} // namespace kernels } // namespace lite } // namespace paddle diff --git a/lite/kernels/cuda/yolo_box_compute_test.cc b/lite/kernels/cuda/yolo_box_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5cd957938319cff216014e3a97d7348d223884e7 --- /dev/null +++ b/lite/kernels/cuda/yolo_box_compute_test.cc @@ -0,0 +1,258 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/yolo_box_compute.h" +#include +#include +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +inline static float sigmoid(float x) { return 1.f / (1.f + expf(-x)); } + +inline static void get_yolo_box(float* box, + const float* x, + const int* anchors, + int i, + int j, + int an_idx, + int grid_size, + int input_size, + int index, + int stride, + int img_height, + int img_width) { + box[0] = (i + sigmoid(x[index])) * img_width / grid_size; + box[1] = (j + sigmoid(x[index + stride])) * img_height / grid_size; + box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width / + input_size; + box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] * + img_height / input_size; +} + +inline static int get_entry_index(int batch, + int an_idx, + int hw_idx, + int an_num, + int an_stride, + int stride, + int entry) { + return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx; +} + +inline static void calc_detection_box(float* boxes, + float* box, + const int box_idx, + const int img_height, + const int img_width) { + boxes[box_idx] = box[0] - box[2] / 2; + boxes[box_idx + 1] = box[1] - box[3] / 2; + boxes[box_idx + 2] = box[0] + box[2] / 2; + boxes[box_idx + 3] = box[1] + box[3] / 2; + + boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : static_cast(0); + boxes[box_idx + 1] = + boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast(0); + boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1 + ? boxes[box_idx + 2] + : static_cast(img_width - 1); + boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1 + ? boxes[box_idx + 3] + : static_cast(img_height - 1); +} + +inline static void calc_label_score(float* scores, + const float* input, + const int label_idx, + const int score_idx, + const int class_num, + const float conf, + const int stride) { + for (int i = 0; i < class_num; i++) { + scores[score_idx + i] = conf * sigmoid(input[label_idx + i * stride]); + } +} + +template +static void YoloBoxRef(const T* input, + const T* imgsize, + T* boxes, + T* scores, + const float conf_thresh, + const int* anchors, + const int n, + const int h, + const int w, + const int an_num, + const int class_num, + const int box_num, + int input_size) { + const int stride = h * w; + const int an_stride = (class_num + 5) * stride; + float box[4]; + + for (int i = 0; i < n; i++) { + int img_height = static_cast(imgsize[2 * i]); + int img_width = static_cast(imgsize[2 * i + 1]); + + for (int j = 0; j < an_num; j++) { + for (int k = 0; k < h; k++) { + for (int l = 0; l < w; l++) { + int obj_idx = + get_entry_index(i, j, k * w + l, an_num, an_stride, stride, 4); + float conf = sigmoid(input[obj_idx]); + if (conf < conf_thresh) { + continue; + } + + int box_idx = + get_entry_index(i, j, k * w + l, an_num, an_stride, stride, 0); + get_yolo_box(box, + input, + anchors, + l, + k, + j, + h, + input_size, + box_idx, + stride, + img_height, + img_width); + box_idx = (i * box_num + j * stride + k * w + l) * 4; + calc_detection_box(boxes, box, box_idx, img_height, img_width); + + int label_idx = + get_entry_index(i, j, k * w + l, an_num, an_stride, stride, 5); + int score_idx = (i * box_num + j * stride + k * w + l) * class_num; + calc_label_score( + scores, input, label_idx, score_idx, class_num, conf, stride); + } + } + } + } +} + +TEST(yolo_box, normal) { + YoloBoxCompute yolo_box_kernel; + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + + operators::YoloBoxParam param; + + lite::Tensor x, sz, x_cpu, sz_cpu; + lite::Tensor boxes, scores, boxes_cpu, scores_cpu; + lite::Tensor x_ref, sz_ref, boxes_ref, scores_ref; + int s = 3, cls = 4; + int n = 1, c = s * (5 + cls), h = 16, w = 16; + param.anchors = {2, 3, 4, 5, 8, 10}; + param.downsample_ratio = 2; + param.conf_thresh = 0.5; + param.class_num = cls; + int m = h * w * param.anchors.size() / 2; + + x.Resize({n, c, h, w}); + sz.Resize({1, 2}); + boxes.Resize({n, m, 4}); + scores.Resize({n, cls, m}); + + x_cpu.Resize({n, c, h, w}); + sz_cpu.Resize({1, 2}); + boxes_cpu.Resize({n, m, 4}); + scores_cpu.Resize({n, cls, m}); + + x_ref.Resize({n, c, h, w}); + sz_ref.Resize({1, 2}); + boxes_ref.Resize({n, m, 4}); + scores_ref.Resize({n, cls, m}); + + auto* x_data = x.mutable_data(TARGET(kCUDA)); + auto* sz_data = sz.mutable_data(TARGET(kCUDA)); + auto* boxes_data = boxes.mutable_data(TARGET(kCUDA)); + auto* scores_data = scores.mutable_data(TARGET(kCUDA)); + + float* x_cpu_data = x_cpu.mutable_data(); + float* sz_cpu_data = sz_cpu.mutable_data(); + float* boxes_cpu_data = boxes_cpu.mutable_data(); + float* scores_cpu_data = scores_cpu.mutable_data(); + + float* x_ref_data = x_ref.mutable_data(); + float* sz_ref_data = sz_ref.mutable_data(); + float* boxes_ref_data = boxes_ref.mutable_data(); + float* scores_ref_data = scores_ref.mutable_data(); + + for (int i = 0; i < x_cpu.numel(); ++i) { + x_cpu_data[i] = i - 5.0; + x_ref_data[i] = i - 5.0; + } + sz_cpu_data[0] = 16; + sz_cpu_data[1] = 32; + sz_ref_data[0] = 16; + sz_ref_data[1] = 32; + + x.Assign(x_cpu_data, x_cpu.dims()); + sz.Assign(sz_cpu_data, sz_cpu.dims()); + + param.X = &x; + param.ImgSize = &sz; + param.Boxes = &boxes; + param.Scores = &scores; + yolo_box_kernel.SetParam(param); + + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + + yolo_box_kernel.SetContext(std::move(ctx)); + yolo_box_kernel.Launch(); + cudaDeviceSynchronize(); + + CopySync(boxes_cpu_data, + boxes_data, + sizeof(float) * boxes.numel(), + IoDirection::DtoH); + CopySync(scores_cpu_data, + scores_data, + sizeof(float) * scores.numel(), + IoDirection::DtoH); + + YoloBoxRef(x_ref_data, + sz_ref_data, + boxes_ref_data, + scores_ref_data, + param.conf_thresh, + param.anchors.data(), + n, + h, + w, + param.anchors.size() / 2, + cls, + m, + param.downsample_ratio * h); + + for (int i = 0; i < boxes.numel(); i++) { + EXPECT_NEAR(boxes_cpu_data[i], boxes_ref_data[i], 1e-5); + } + for (int i = 0; i < scores.numel(); i++) { + EXPECT_NEAR(scores_cpu_data[i], scores_ref_data[i], 1e-5); + } +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/host/CMakeLists.txt b/lite/kernels/host/CMakeLists.txt index abd96317cc2180ecf94e99835ab89216762b8f52..5bcaafeabbe79b0608eca0388f2f0e8f185b108f 100644 --- a/lite/kernels/host/CMakeLists.txt +++ b/lite/kernels/host/CMakeLists.txt @@ -3,5 +3,7 @@ message(STATUS "compile with lite host kernels") add_kernel(feed_compute_host Host basic SRCS feed_compute.cc DEPS ${lite_kernel_deps}) add_kernel(fetch_compute_host Host basic SRCS fetch_compute.cc DEPS ${lite_kernel_deps}) add_kernel(reshape_compute_host Host basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op) +add_kernel(multiclass_nms_compute_host Host basic SRCS multiclass_nms_compute.cc DEPS ${lite_kernel_deps}) lite_cc_test(test_reshape_compute_host SRCS reshape_compute_test.cc DEPS reshape_compute_host any) +lite_cc_test(test_multiclass_nms_compute_host SRCS multiclass_nms_compute_test.cc DEPS multiclass_nms_compute_host any) diff --git a/lite/arm/math/multiclass_nms.cc b/lite/kernels/host/multiclass_nms_compute.cc similarity index 74% rename from lite/arm/math/multiclass_nms.cc rename to lite/kernels/host/multiclass_nms_compute.cc index 3baeb2d84439acabf90be6ab31bb3cf78c8dc9c6..0d490d6011ab2f8e9e74f0e3994e9fd696298553 100644 --- a/lite/arm/math/multiclass_nms.cc +++ b/lite/kernels/host/multiclass_nms_compute.cc @@ -12,13 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "lite/arm/math/multiclass_nms.h" -#include "lite/arm/math/funcs.h" +#include "lite/kernels/host/multiclass_nms_compute.h" +#include +#include +#include namespace paddle { namespace lite { -namespace arm { -namespace math { +namespace kernels { +namespace host { template static bool sort_score_pair_descend(const std::pair& pair1, @@ -269,31 +271,92 @@ void multiclass_nms(const dtype* bbox_cpu_data, } } -template float jaccard_overlap(const float* bbox1, const float* bbox2); - -template void apply_nms_fast(const float* bboxes, - const float* scores, - int num, - float score_threshold, - float nms_threshold, - float eta, - int top_k, - std::vector* indices); - -template void multiclass_nms(const float* bbox_cpu_data, - const float* conf_cpu_data, - std::vector* result, - const std::vector& priors, - int class_num, - int background_id, - int keep_topk, - int nms_topk, - float conf_thresh, - float nms_thresh, - float nms_eta, - bool share_location); - -} // namespace math -} // namespace arm +void MulticlassNmsCompute::Run() { + auto& param = Param(); + // bbox shape : N, M, 4 + // scores shape : N, C, M + const float* bbox_data = param.bbox_data->data(); + const float* conf_data = param.conf_data->data(); + + CHECK_EQ(param.bbox_data->dims().production() % 4, 0); + + std::vector result; + int N = param.bbox_data->dims()[0]; + int M = param.bbox_data->dims()[1]; + std::vector priors(N, M); + int class_num = param.conf_data->dims()[1]; + int background_label = param.background_label; + int keep_top_k = param.keep_top_k; + int nms_top_k = param.nms_top_k; + float score_threshold = param.score_threshold; + float nms_threshold = param.nms_threshold; + float nms_eta = param.nms_eta; + bool share_location = param.share_location; + + multiclass_nms(bbox_data, + conf_data, + &result, + priors, + class_num, + background_label, + keep_top_k, + nms_top_k, + score_threshold, + nms_threshold, + nms_eta, + share_location); + + lite::LoD lod; + std::vector lod_info; + lod_info.push_back(0); + std::vector result_corrected; + int tmp_batch_id; + uint64_t num = 0; + for (int i = 0; i < result.size(); ++i) { + if (i == 0) { + tmp_batch_id = result[i]; + } + if (i % 7 == 0) { + if (result[i] == tmp_batch_id) { + ++num; + } else { + lod_info.push_back(num); + ++num; + tmp_batch_id = result[i]; + } + } else { + result_corrected.push_back(result[i]); + } + } + lod_info.push_back(num); + lod.push_back(lod_info); + if (result_corrected.empty()) { + lod.clear(); + lod.push_back(std::vector({0, 1})); + param.out->Resize({static_cast(1)}); + param.out->mutable_data()[0] = -1.; + param.out->set_lod(lod); + } else { + param.out->Resize({static_cast(result_corrected.size() / 6), 6}); + float* out = param.out->mutable_data(); + std::memcpy( + out, result_corrected.data(), sizeof(float) * result_corrected.size()); + param.out->set_lod(lod); + } +} + +} // namespace host +} // namespace kernels } // namespace lite } // namespace paddle + +REGISTER_LITE_KERNEL(multiclass_nms, + kHost, + kFloat, + kNCHW, + paddle::lite::kernels::host::MulticlassNmsCompute, + def) + .BindInput("BBoxes", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindInput("Scores", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) + .Finalize(); diff --git a/lite/kernels/arm/multiclass_nms_compute.h b/lite/kernels/host/multiclass_nms_compute.h similarity index 85% rename from lite/kernels/arm/multiclass_nms_compute.h rename to lite/kernels/host/multiclass_nms_compute.h index 6b77e216fa5213be71a7535c47d89ed09c3c1148..9391a0b2b5231c30c6f490039ec956326efa5bf9 100644 --- a/lite/kernels/arm/multiclass_nms_compute.h +++ b/lite/kernels/host/multiclass_nms_compute.h @@ -13,26 +13,24 @@ // limitations under the License. #pragma once -#include +#include #include "lite/core/kernel.h" #include "lite/core/op_registry.h" namespace paddle { namespace lite { namespace kernels { -namespace arm { +namespace host { class MulticlassNmsCompute - : public KernelLite { + : public KernelLite { public: - using param_t = operators::MulticlassNmsParam; - void Run() override; virtual ~MulticlassNmsCompute() = default; }; -} // namespace arm +} // namespace host } // namespace kernels } // namespace lite } // namespace paddle diff --git a/lite/kernels/arm/multiclass_nms_compute_test.cc b/lite/kernels/host/multiclass_nms_compute_test.cc similarity index 95% rename from lite/kernels/arm/multiclass_nms_compute_test.cc rename to lite/kernels/host/multiclass_nms_compute_test.cc index b0352f77c5a5f1f26d058ffee040e696fd55e7f9..37c04bc2902cb0fc1d67095c48ac40edf695f830 100644 --- a/lite/kernels/arm/multiclass_nms_compute_test.cc +++ b/lite/kernels/host/multiclass_nms_compute_test.cc @@ -12,19 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "lite/kernels/arm/multiclass_nms_compute.h" +#include "lite/kernels/host/multiclass_nms_compute.h" #include -#include #include -#include #include #include -#include "lite/core/op_registry.h" namespace paddle { namespace lite { namespace kernels { -namespace arm { +namespace host { template static bool sort_score_pair_descend(const std::pair& pair1, @@ -279,21 +276,21 @@ void multiclass_nms_compute_ref(const operators::MulticlassNmsParam& param, } } -TEST(multiclass_nms_arm, retrive_op) { +TEST(multiclass_nms_host, init) { + MulticlassNmsCompute multiclass_nms; + ASSERT_EQ(multiclass_nms.precision(), PRECISION(kFloat)); + ASSERT_EQ(multiclass_nms.target(), TARGET(kHost)); +} + +TEST(multiclass_nms_host, retrive_op) { auto multiclass_nms = - KernelRegistry::Global().Create( + KernelRegistry::Global().Create( "multiclass_nms"); ASSERT_FALSE(multiclass_nms.empty()); ASSERT_TRUE(multiclass_nms.front()); } -TEST(multiclass_nms_arm, init) { - MulticlassNmsCompute multiclass_nms; - ASSERT_EQ(multiclass_nms.precision(), PRECISION(kFloat)); - ASSERT_EQ(multiclass_nms.target(), TARGET(kARM)); -} - -TEST(multiclass_nms_arm, compute) { +TEST(multiclass_nms_host, compute) { MulticlassNmsCompute multiclass_nms; operators::MulticlassNmsParam param; lite::Tensor bbox, conf, out; @@ -306,9 +303,6 @@ TEST(multiclass_nms_arm, compute) { DDim* bbox_dim; DDim* conf_dim; int M = priors[0]; - // for (int i = 0; i < priors.size(); ++i) { - // M += priors[i]; - //} if (share_location) { bbox_dim = new DDim({N, M, 4}); } else { @@ -368,9 +362,9 @@ TEST(multiclass_nms_arm, compute) { } } -} // namespace arm +} // namespace host } // namespace kernels } // namespace lite } // namespace paddle -USE_LITE_KERNEL(multiclass_nms, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(multiclass_nms, kHost, kFloat, kNCHW, def); diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt index a78424dede7a54f22de0a8c007c02c89f6982733..482d2d9520d2e5d381a60ff9b0faa053d0a03658 100644 --- a/lite/tests/kernels/CMakeLists.txt +++ b/lite/tests/kernels/CMakeLists.txt @@ -33,7 +33,6 @@ endif() lite_cc_test(test_sgemm SRCS test_sgemm.cc DEPS ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_pad2d_compute SRCS pad2d_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_prior_box_compute SRCS prior_box_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) - lite_cc_test(test_kernel_multiclass_nms_compute SRCS multiclass_nms_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_negative_compute SRCS negative_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_bilinear_interp_compute SRCS bilinear_interp_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_nearest_interp_compute SRCS nearest_interp_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) diff --git a/lite/tests/kernels/multiclass_nms_compute_test.cc b/lite/tests/kernels/multiclass_nms_compute_test.cc deleted file mode 100644 index 1658acd0c782880b4430db541b0e62daec088b74..0000000000000000000000000000000000000000 --- a/lite/tests/kernels/multiclass_nms_compute_test.cc +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include "lite/api/paddle_use_kernels.h" -#include "lite/api/paddle_use_ops.h" -#include "lite/core/arena/framework.h" - -bool read_file(std::vector* result, const char* file_name) { - std::ifstream infile(file_name); - - if (!infile.good()) { - std::cout << "Cannot open " << file_name << std::endl; - return false; - } - - LOG(INFO) << "found filename: " << file_name; - std::string line; - - while (std::getline(infile, line)) { - (*result).push_back(static_cast(atof(line.c_str()))); - } - - return true; -} - -const char* bboxes_file = "multiclass_nms_bboxes_file.txt"; -const char* scores_file = "multiclass_nms_scores_file.txt"; -const char* out_file = "multiclass_nms_out_file.txt"; - -namespace paddle { -namespace lite { -class MulticlassNmsComputeTester : public arena::TestCase { - protected: - // common attributes for this op. - std::string bbox_ = "BBoxes"; - std::string conf_ = "Scores"; - std::string out_ = "Out"; - std::vector priors_; - int class_num_; - int background_id_; - int keep_topk_; - int nms_topk_; - float conf_thresh_; - float nms_thresh_; - float nms_eta_; - bool share_location_; - DDim bbox_dims_; - DDim conf_dims_; - - public: - MulticlassNmsComputeTester(const Place& place, - const std::string& alias, - std::vector priors, - int class_num, - int background_id, - int keep_topk, - int nms_topk, - float conf_thresh, - float nms_thresh, - float nms_eta, - bool share_location, - DDim bbox_dims, - DDim conf_dims) - : TestCase(place, alias), - priors_(priors), - class_num_(class_num), - background_id_(background_id), - keep_topk_(keep_topk), - nms_topk_(nms_topk), - conf_thresh_(conf_thresh), - nms_thresh_(nms_thresh), - nms_eta_(nms_eta), - share_location_(share_location), - bbox_dims_(bbox_dims), - conf_dims_(conf_dims) {} - - void RunBaseline(Scope* scope) override { - std::vector vbbox; - std::vector vscore; - std::vector vout; - - if (!read_file(&vout, out_file)) { - LOG(ERROR) << "load ground truth failed"; - return; - } - - auto* out = scope->NewTensor(out_); - CHECK(out); - out->Resize(DDim({static_cast(vout.size() / 6), 6})); - auto* out_data = out->mutable_data(); - memcpy(out_data, vout.data(), vout.size() * sizeof(float)); - out->mutable_lod()->push_back(std::vector({0, 10})); - } - - void PrepareOpDesc(cpp::OpDesc* op_desc) { - op_desc->SetType("multiclass_nms"); - op_desc->SetInput("BBoxes", {bbox_}); - op_desc->SetInput("Scores", {conf_}); - op_desc->SetOutput("Out", {out_}); - op_desc->SetAttr("background_label", background_id_); - op_desc->SetAttr("keep_top_k", keep_topk_); - op_desc->SetAttr("nms_top_k", nms_topk_); - op_desc->SetAttr("score_threshold", conf_thresh_); - op_desc->SetAttr("nms_threshold", nms_thresh_); - op_desc->SetAttr("nms_eta", nms_eta_); - op_desc->SetAttr("share_location", share_location_); - } - - void PrepareData() override { - std::vector bbox_data; - std::vector conf_data; - - if (!read_file(&bbox_data, bboxes_file)) { - LOG(ERROR) << "load bbox file failed"; - return; - } - if (!read_file(&conf_data, scores_file)) { - LOG(ERROR) << "load score file failed"; - return; - } - - SetCommonTensor(bbox_, bbox_dims_, bbox_data.data()); - SetCommonTensor(conf_, conf_dims_, conf_data.data()); - } -}; - -void test_multiclass_nms(Place place) { - int keep_top_k = 200; - int nms_top_k = 400; - float nms_eta = 1.; - float score_threshold = 0.009999999776482582; - int background_label = 0; - float nms_threshold = 0.44999998807907104; - int N = 1; - int M = 1917; - int class_num = 21; - bool share_location = true; - std::vector priors(N, M); - - std::unique_ptr tester( - new MulticlassNmsComputeTester(place, - "def", - priors, - class_num, - background_label, - keep_top_k, - nms_top_k, - score_threshold, - nms_threshold, - nms_eta, - share_location, - DDim({N, M, 4}), - DDim({N, class_num, M}))); - arena::Arena arena(std::move(tester), place, 2e-5); - arena.TestPrecision(); -} - -TEST(MulticlassNms, precision) { -#ifdef LITE_WITH_X86 - Place place(TARGET(kX86)); -#endif -#ifdef LITE_WITH_ARM - Place place(TARGET(kARM)); - test_multiclass_nms(place); -#endif -} - -} // namespace lite -} // namespace paddle diff --git a/lite/tools/ci_build.sh b/lite/tools/ci_build.sh index 80432f4378cd615109c0647b8667cc69417bbe6a..433e08c2516ac66a376bd18c4bf0c7317948cc19 100755 --- a/lite/tools/ci_build.sh +++ b/lite/tools/ci_build.sh @@ -524,22 +524,6 @@ function build_npu { fi } -function __prepare_multiclass_nms_test_files { - local port=$1 - local adb_work_dir="/data/local/tmp" - - wget --no-check-certificate https://raw.githubusercontent.com/jiweibo/TestData/master/multiclass_nms_bboxes_file.txt \ - -O lite/tests/kernels/multiclass_nms_bboxes_file.txt - wget --no-check-certificate https://raw.githubusercontent.com/jiweibo/TestData/master/multiclass_nms_scores_file.txt \ - -O lite/tests/kernels/multiclass_nms_scores_file.txt - wget --no-check-certificate https://raw.githubusercontent.com/jiweibo/TestData/master/multiclass_nms_out_file.txt \ - -O lite/tests/kernels/multiclass_nms_out_file.txt - - adb -s emulator-${port} push lite/tests/kernels/multiclass_nms_bboxes_file.txt ${adb_work_dir} - adb -s emulator-${port} push lite/tests/kernels/multiclass_nms_scores_file.txt ${adb_work_dir} - adb -s emulator-${port} push lite/tests/kernels/multiclass_nms_out_file.txt ${adb_work_dir} -} - # $1: ARM_TARGET_OS in "android" , "armlinux" # $2: ARM_TARGET_ARCH_ABI in "armv8", "armv7" ,"armv7hf" # $3: ARM_TARGET_LANG in "gcc" "clang" @@ -562,9 +546,6 @@ function test_arm { return 0 fi - echo "prepare multiclass_nms_test files..." - __prepare_multiclass_nms_test_files $port - # prepare for CXXApi test local adb="adb -s emulator-${port}" $adb shell mkdir -p /data/local/tmp/lite_naive_model_opt