未验证 提交 1b1d7a83 编写于 作者: J juncaipeng 提交者: GitHub

fix yolo_box bug (#2034)

* fix yolo_box bug, test=develop

* fix test bug for yolo_box, test=develop
上级 a2a5c8b0
......@@ -108,7 +108,7 @@ void yolobox(lite::Tensor* X,
auto anchors_data = anchors.data();
const float* X_data = X->data<float>();
float* ImgSize_data = ImgSize->mutable_data<float>();
int* ImgSize_data = ImgSize->mutable_data<int>();
float* Boxes_data = Boxes->mutable_data<float>();
......@@ -116,8 +116,8 @@ void yolobox(lite::Tensor* X,
float box[4];
for (int i = 0; i < n; i++) {
int img_height = static_cast<int>(ImgSize_data[2 * i]);
int img_width = static_cast<int>(ImgSize_data[2 * i + 1]);
int img_height = ImgSize_data[2 * i];
int img_width = ImgSize_data[2 * i + 1];
for (int j = 0; j < an_num; j++) {
for (int k = 0; k < h; k++) {
......
......@@ -31,6 +31,19 @@ bool YoloBoxOp::CheckShape() const {
CHECK_OR_FALSE(ImgSize);
CHECK_OR_FALSE(Boxes);
CHECK_OR_FALSE(Scores);
auto dim_x = X->dims();
auto dim_imgsize = ImgSize->dims();
std::vector<int> anchors = param_.anchors;
int anchor_num = anchors.size() / 2;
auto class_num = param_.class_num;
CHECK_OR_FALSE(dim_x.size() == 4);
CHECK_OR_FALSE(dim_x[1] == anchor_num * (5 + class_num));
CHECK_OR_FALSE(dim_imgsize[0] == dim_x[0]);
CHECK_OR_FALSE(dim_imgsize[1] == 2);
CHECK_OR_FALSE(anchors.size() > 0 && anchors.size() % 2 == 0);
CHECK_OR_FALSE(class_num > 0);
return true;
}
bool YoloBoxOp::InferShape() const {
......
......@@ -101,7 +101,7 @@ class YoloBoxComputeTester : public arena::TestCase {
float conf_thresh_ = 0.f;
int downsample_ratio_ = 0;
DDim _dims0_{{1, 2, 2, 1}};
DDim _dims0_{{1, 255, 13, 13}};
DDim _dims1_{{1, 2}};
public:
......@@ -115,7 +115,10 @@ class YoloBoxComputeTester : public arena::TestCase {
anchors_(anchors),
class_num_(class_num),
conf_thresh_(conf_thresh),
downsample_ratio_(downsample_ratio) {}
downsample_ratio_(downsample_ratio) {
int anchor_num = anchors_.size() / 2;
_dims0_[1] = anchor_num * (5 + class_num);
}
void RunBaseline(Scope* scope) override {
const lite::Tensor* X = scope->FindTensor(input0_);
......@@ -149,14 +152,14 @@ class YoloBoxComputeTester : public arena::TestCase {
auto anchors_data = anchors.data();
const float* in_data = in->data<float>();
const float* imgsize_data = imgsize->data<float>();
const int* imgsize_data = imgsize->data<int>();
float* boxes_data = boxes->mutable_data<float>();
float* scores_data = scores->mutable_data<float>();
float box[4];
for (int i = 0; i < n; i++) {
int img_height = static_cast<int>(imgsize_data[2 * i]);
int img_width = static_cast<int>(imgsize_data[2 * i + 1]);
int img_height = imgsize_data[2 * i];
int img_width = imgsize_data[2 * i + 1];
for (int j = 0; j < an_num; j++) {
for (int k = 0; k < h; k++) {
for (int l = 0; l < w; l++) {
......@@ -218,7 +221,7 @@ class YoloBoxComputeTester : public arena::TestCase {
}
std::vector<int> data1(_dims1_.production());
for (int i = 0; i < _dims1_.production(); i++) {
data1[i] = i + 8;
data1[i] = 608;
}
SetCommonTensor(input0_, _dims0_, data0.data());
SetCommonTensor(input1_, _dims1_, data1.data());
......@@ -227,10 +230,9 @@ class YoloBoxComputeTester : public arena::TestCase {
void test_yolobox(Place place) {
for (int class_num : {1, 2, 3, 4}) {
for (float conf_thresh : {0.5, 0.2, 0.7}) {
for (int downsample_ratio : {1, 2, 3}) {
std::vector<int> anchor({1, 2, 3, 4});
for (float conf_thresh : {0.01, 0.2, 0.7}) {
for (int downsample_ratio : {16, 32}) {
std::vector<int> anchor({10, 13, 16, 30});
std::unique_ptr<arena::TestCase> tester(new YoloBoxComputeTester(
place, "def", anchor, class_num, conf_thresh, downsample_ratio));
arena::Arena arena(std::move(tester), place, 2e-5);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册