diff --git a/lite/backends/arm/math/yolo_box.cc b/lite/backends/arm/math/yolo_box.cc index 72e67cf69331ac2e0fa6edc7c8cd4a99ee763071..7ddb262480bbc427cda68b199a39fdef50a214c3 100644 --- a/lite/backends/arm/math/yolo_box.cc +++ b/lite/backends/arm/math/yolo_box.cc @@ -108,7 +108,7 @@ void yolobox(lite::Tensor* X, auto anchors_data = anchors.data(); const float* X_data = X->data(); - float* ImgSize_data = ImgSize->mutable_data(); + int* ImgSize_data = ImgSize->mutable_data(); float* Boxes_data = Boxes->mutable_data(); @@ -116,8 +116,8 @@ void yolobox(lite::Tensor* X, float box[4]; for (int i = 0; i < n; i++) { - int img_height = static_cast(ImgSize_data[2 * i]); - int img_width = static_cast(ImgSize_data[2 * i + 1]); + int img_height = ImgSize_data[2 * i]; + int img_width = ImgSize_data[2 * i + 1]; for (int j = 0; j < an_num; j++) { for (int k = 0; k < h; k++) { diff --git a/lite/operators/yolo_box_op.cc b/lite/operators/yolo_box_op.cc index 068cdf043193d3771334f0e7bac33ea190edf5e1..de2ce77dfbf610707b3c62f7fa02a8413459893b 100644 --- a/lite/operators/yolo_box_op.cc +++ b/lite/operators/yolo_box_op.cc @@ -31,6 +31,19 @@ bool YoloBoxOp::CheckShape() const { CHECK_OR_FALSE(ImgSize); CHECK_OR_FALSE(Boxes); CHECK_OR_FALSE(Scores); + + auto dim_x = X->dims(); + auto dim_imgsize = ImgSize->dims(); + std::vector anchors = param_.anchors; + int anchor_num = anchors.size() / 2; + auto class_num = param_.class_num; + CHECK_OR_FALSE(dim_x.size() == 4); + CHECK_OR_FALSE(dim_x[1] == anchor_num * (5 + class_num)); + CHECK_OR_FALSE(dim_imgsize[0] == dim_x[0]); + CHECK_OR_FALSE(dim_imgsize[1] == 2); + CHECK_OR_FALSE(anchors.size() > 0 && anchors.size() % 2 == 0); + CHECK_OR_FALSE(class_num > 0); + return true; } bool YoloBoxOp::InferShape() const { diff --git a/lite/tests/kernels/yolo_box_compute_test.cc b/lite/tests/kernels/yolo_box_compute_test.cc index a051e06b6bcb23647f8b9f467b9f76a751fecec4..2e98ce96cef479d55e77acebbe464d9a56f92934 100644 --- a/lite/tests/kernels/yolo_box_compute_test.cc +++ b/lite/tests/kernels/yolo_box_compute_test.cc @@ -101,7 +101,7 @@ class YoloBoxComputeTester : public arena::TestCase { float conf_thresh_ = 0.f; int downsample_ratio_ = 0; - DDim _dims0_{{1, 2, 2, 1}}; + DDim _dims0_{{1, 255, 13, 13}}; DDim _dims1_{{1, 2}}; public: @@ -115,7 +115,10 @@ class YoloBoxComputeTester : public arena::TestCase { anchors_(anchors), class_num_(class_num), conf_thresh_(conf_thresh), - downsample_ratio_(downsample_ratio) {} + downsample_ratio_(downsample_ratio) { + int anchor_num = anchors_.size() / 2; + _dims0_[1] = anchor_num * (5 + class_num); + } void RunBaseline(Scope* scope) override { const lite::Tensor* X = scope->FindTensor(input0_); @@ -149,14 +152,14 @@ class YoloBoxComputeTester : public arena::TestCase { auto anchors_data = anchors.data(); const float* in_data = in->data(); - const float* imgsize_data = imgsize->data(); + const int* imgsize_data = imgsize->data(); float* boxes_data = boxes->mutable_data(); float* scores_data = scores->mutable_data(); float box[4]; for (int i = 0; i < n; i++) { - int img_height = static_cast(imgsize_data[2 * i]); - int img_width = static_cast(imgsize_data[2 * i + 1]); + int img_height = imgsize_data[2 * i]; + int img_width = imgsize_data[2 * i + 1]; for (int j = 0; j < an_num; j++) { for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { @@ -218,7 +221,7 @@ class YoloBoxComputeTester : public arena::TestCase { } std::vector data1(_dims1_.production()); for (int i = 0; i < _dims1_.production(); i++) { - data1[i] = i + 8; + data1[i] = 608; } SetCommonTensor(input0_, _dims0_, data0.data()); SetCommonTensor(input1_, _dims1_, data1.data()); @@ -227,10 +230,9 @@ class YoloBoxComputeTester : public arena::TestCase { void test_yolobox(Place place) { for (int class_num : {1, 2, 3, 4}) { - for (float conf_thresh : {0.5, 0.2, 0.7}) { - for (int downsample_ratio : {1, 2, 3}) { - std::vector anchor({1, 2, 3, 4}); - + for (float conf_thresh : {0.01, 0.2, 0.7}) { + for (int downsample_ratio : {16, 32}) { + std::vector anchor({10, 13, 16, 30}); std::unique_ptr tester(new YoloBoxComputeTester( place, "def", anchor, class_num, conf_thresh, downsample_ratio)); arena::Arena arena(std::move(tester), place, 2e-5);