From 961279b5d229df0f1a29b01fbcc666d891fd0a18 Mon Sep 17 00:00:00 2001 From: Di Wu Date: Wed, 14 Apr 2021 12:39:20 +0800 Subject: [PATCH] Allow yolov5 non square letter box input with args -s and -c (#617) * Allow yolov5 non square letterbox input. * free malloc and retract getopt args for letterbox rows and cols. --- examples/tm_yolov5s.cpp | 270 +++++++++++++++---------------- examples/tm_yolov5s_timvx.cpp | 292 ++++++++++++++++------------------ 2 files changed, 267 insertions(+), 295 deletions(-) diff --git a/examples/tm_yolov5s.cpp b/examples/tm_yolov5s.cpp index 0a08a1c4..43348a75 100644 --- a/examples/tm_yolov5s.cpp +++ b/examples/tm_yolov5s.cpp @@ -20,9 +20,10 @@ /* * Copyright (c) 2021, OPEN AI LAB * Author: xwwang@openailab.com + * + * Author: stevenwudi@fiture.com */ -#include #include #include #include @@ -133,108 +134,16 @@ static void nms_sorted_bboxes(const std::vector& faceobjects, std::vecto } } -void get_input_data_focas(const char* image_file, float* input_data, int img_h, int img_w, const float* mean, const float* scale) -{ - cv::Mat sample = cv::imread(image_file, 1); - cv::Mat img; - - /* convert to RGB */ - if (sample.channels() == 1) - cv::cvtColor(sample, img, cv::COLOR_GRAY2RGB); - else - cv::cvtColor(sample, img, cv::COLOR_BGR2RGB); - - /* letterbox process */ - float letterbox_size = img_h; - int resize_h = 0; - int resize_w = 0; - if (img.rows > img.cols) - { - resize_h = letterbox_size; - resize_w = int(img.cols * (letterbox_size / img.rows)); - } - else - { - resize_h = int(img.rows * (letterbox_size / img.cols)); - resize_w = letterbox_size; - } - - cv::resize(img, img, cv::Size(resize_w, resize_h)); - img.convertTo(img, CV_32FC3); - cv::Mat resize_img(letterbox_size, letterbox_size, CV_32FC3, - cv::Scalar(0.5/scale[0] + mean[0], 0.5/scale[1] + mean[1], 0.5/ scale[2] + mean[2])); - int dh = int((letterbox_size - resize_h) / 2); - int dw = int((letterbox_size - resize_w) / 2); - - for (int h = 0; h < resize_h; h++) - { - for (int w = 0; w < resize_w; w++) - { - for (int c = 0; c < 3; ++c) - { - int in_index = h * resize_w * 3 + w * 3 + c; - int out_index = (dh + h) * letterbox_size * 3 + (dw + w) * 3 + c; - - (( float* )resize_img.data)[out_index] = (( float* )img.data)[in_index]; - } - } - } - - resize_img.convertTo(resize_img, CV_32FC3); - float* img_data = (float* )resize_img.data; - float* input_temp = (float* )malloc(3 * letterbox_size * letterbox_size * sizeof(float)); - /* nhwc to nchw */ - for (int h = 0; h < img_h; h++) - { - for (int w = 0; w < img_w; w++) - { - for (int c = 0; c < 3; c++) - { - int in_index = h * img_w * 3 + w * 3 + c; - int out_index = c * img_h * img_w + h * img_w + w; - input_temp[out_index] = (img_data[in_index] - mean[c]) * scale[c]; - } - } - } - - /* focus process */ - int input_size = letterbox_size / 2; - for (int i = 0; i < 2; i++) - { - for (int g = 0; g < 2; g++) - { - for (int c = 0; c < 3; c++) - { - for (int w = 0; w < input_size; w++) - { - for (int h = 0; h < input_size; h++) - { - int in_index = i + g * letterbox_size + c * letterbox_size * letterbox_size + w * 2 * letterbox_size + h * 2; - int out_index = i * 2 * 3 * input_size * input_size + - g * 3 * input_size * input_size + - c * input_size * input_size + - w * input_size + - h; - input_data[out_index] = input_temp[in_index]; - } - } - } - } - } - - free(input_temp); -} - -static void generate_proposals(int stride, const float* feat, float prob_threshold, std::vector& objects) -{ +static void generate_proposals(int stride, const float* feat, float prob_threshold, std::vector& objects, + int letterbox_cols, int letterbox_rows){ static float anchors[18] = {10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326}; int anchor_num = 3; - int feat_w = 640 / stride; - int feat_h = 640 / stride; + int feat_w = letterbox_cols / stride; + int feat_h = letterbox_rows / stride; int cls_num = 80; - int anchor_group = 0; + int anchor_group; if(stride == 8) anchor_group = 1; if(stride == 16) @@ -297,15 +206,15 @@ static void generate_proposals(int stride, const float* feat, float prob_thresho static void draw_objects(const cv::Mat& bgr, const std::vector& objects) { static const char* class_names[] = { - "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", - "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", - "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", - "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", - "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", - "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", - "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", - "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", - "hair drier", "toothbrush" + "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", + "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", + "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", + "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", + "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", + "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", + "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", + "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", + "hair drier", "toothbrush" }; cv::Mat image = bgr.clone(); @@ -345,21 +254,105 @@ static void draw_objects(const cv::Mat& bgr, const std::vector& objects) void show_usage() { fprintf( - stderr, - "[Usage]: [-h]\n [-m model_file] [-i image_file] [-r repeat_count] [-t thread_count] \n"); + stderr, + "[Usage]: [-h]\n [-m model_file] [-i image_file] [-r repeat_count] [-t thread_count]\n"); +} + +void get_input_data_focus(const char* image_file, float* input_data, int letterbox_rows, int letterbox_cols, const float* mean, const float* scale) +{ + cv::Mat sample = cv::imread(image_file, 1); + cv::Mat img; + + if (sample.channels() == 1) + cv::cvtColor(sample, img, cv::COLOR_GRAY2RGB); + else + cv::cvtColor(sample, img, cv::COLOR_BGR2RGB); + + /* letterbox process to support different letterbox size */ + float scale_letterbox; + int resize_rows; + int resize_cols; + if ((letterbox_rows * 1.0 / img.rows) < (letterbox_cols * 1.0 / img.cols)) { + scale_letterbox = letterbox_rows * 1.0 / img.rows; + } else { + scale_letterbox = letterbox_cols * 1.0 / img.cols; + } + resize_cols = int(scale_letterbox * img.cols); + resize_rows = int(scale_letterbox * img.rows); + + cv::resize(img, img, cv::Size(resize_cols, resize_rows)); + img.convertTo(img, CV_32FC3); + // Generate a gray image for letterbox using opencv + cv::Mat img_new(letterbox_cols, letterbox_rows, CV_32FC3,cv::Scalar(0.5/scale[0] + mean[0], 0.5/scale[1] + mean[1], 0.5/ scale[2] + mean[2])); + int top = (letterbox_rows - resize_rows) / 2; + int bot = (letterbox_rows - resize_rows + 1) / 2; + int left = (letterbox_cols - resize_cols) / 2; + int right = (letterbox_cols - resize_cols + 1) / 2; + // Letterbox filling + cv::copyMakeBorder(img, img_new, top, bot, left, right, cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0)); + + img_new.convertTo(img_new, CV_32FC3); + float* img_data = (float* )img_new.data; + float* input_temp = (float* )malloc(3 * letterbox_cols * letterbox_rows * sizeof(float)); + + /* nhwc to nchw */ + for (int h = 0; h < letterbox_rows; h++) + { + for (int w = 0; w < letterbox_cols; w++) + { + for (int c = 0; c < 3; c++) + { + int in_index = h * letterbox_cols * 3 + w * 3 + c; + int out_index = c * letterbox_rows * letterbox_cols + h * letterbox_cols + w; + input_temp[out_index] = (img_data[in_index] - mean[c]) * scale[c]; + } + } + } + + /* focus process */ + for (int i = 0; i < 2; i++) // corresponding to rows + { + for (int g = 0; g < 2; g++) // corresponding to cols + { + for (int c = 0; c < 3; c++) + { + for (int h = 0; h < letterbox_rows/2; h++) + { + for (int w = 0; w < letterbox_cols/2; w++) + { + int in_index = i + g * letterbox_cols + c * letterbox_cols * letterbox_rows + + h * 2 * letterbox_cols + w * 2; + int out_index = i * 2 * 3 * (letterbox_cols/2) * (letterbox_rows/2) + + g * 3 * (letterbox_cols/2) * (letterbox_rows/2) + + c * (letterbox_cols/2) * (letterbox_rows/2) + + h * (letterbox_cols/2) + + w; + + /* quant to uint8 */ + input_data[out_index] = input_temp[in_index]; + } + } + } + } + } + + free(input_temp); } + int main(int argc, char* argv[]) { const char* model_file = nullptr; const char* image_file = nullptr; - int letterbox_size = 640; - int img_h = letterbox_size; - int img_w = letterbox_size; + int img_c = 3; const float mean[3] = {0, 0, 0}; const float scale[3] = {0.003921, 0.003921, 0.003921}; + // allow none square letterbox, set default letterbox size + int letterbox_rows = 640; + int letterbox_cols = 640; + int repeat_count = 1; int num_thread = 1; @@ -437,8 +430,8 @@ int main(int argc, char* argv[]) return -1; } - int img_size = img_h * img_w * img_c; - int dims[] = {1, 12, int(img_h / 2), int(img_w / 2)}; + int img_size = letterbox_rows * letterbox_cols * img_c; + int dims[] = {1, 12, int(letterbox_rows / 2), int(letterbox_cols / 2)}; float* input_data = ( float* )malloc(img_size * sizeof(float)); tensor_t input_tensor = get_graph_input_tensor(graph, 0, 0); @@ -468,7 +461,7 @@ int main(int argc, char* argv[]) } /* prepare process input data, set the data mem to input tensor */ - get_input_data_focas(image_file, input_data, img_h, img_w, mean, scale); + get_input_data_focus(image_file, input_data, letterbox_rows, letterbox_cols, mean, scale); /* run graph */ double min_time = DBL_MAX; @@ -504,7 +497,7 @@ int main(int argc, char* argv[]) float* p16_data = ( float*)get_tensor_buffer(p16_output); float* p32_data = ( float*)get_tensor_buffer(p32_output); - /* postprocess */ + /* postprocess */ const float prob_threshold = 0.25f; const float nms_threshold = 0.45f; @@ -514,11 +507,11 @@ int main(int argc, char* argv[]) std::vector objects32; std::vector objects; - generate_proposals(32, p32_data, prob_threshold, objects32); + generate_proposals(32, p32_data, prob_threshold, objects32, letterbox_cols, letterbox_rows); proposals.insert(proposals.end(), objects32.begin(), objects32.end()); - generate_proposals(16, p16_data, prob_threshold, objects16); + generate_proposals(16, p16_data, prob_threshold, objects16, letterbox_cols, letterbox_rows); proposals.insert(proposals.end(), objects16.begin(), objects16.end()); - generate_proposals( 8, p8_data, prob_threshold, objects8); + generate_proposals( 8, p8_data, prob_threshold, objects8, letterbox_cols, letterbox_rows); proposals.insert(proposals.end(), objects8.begin(), objects8.end()); qsort_descent_inplace(proposals); @@ -526,28 +519,23 @@ int main(int argc, char* argv[]) nms_sorted_bboxes(proposals, picked, nms_threshold); /* yolov5 draw the result */ - int raw_h = img.rows; - int raw_w = img.cols; - float lb = letterbox_size; - int h0 = 0; - int w0 = 0; - if ( img.rows > img.cols) - { - h0 = lb; - w0 = int(img.cols * (lb / img.rows)); - } - else - { - h0 = int(img.rows * (lb / img.cols)); - w0 = lb; + float scale_letterbox; + int resize_rows; + int resize_cols; + if ((letterbox_rows * 1.0 / img.rows) < (letterbox_cols * 1.0 / img.cols)) { + scale_letterbox = letterbox_rows * 1.0 / img.rows; + } else { + scale_letterbox = letterbox_cols * 1.0 / img.cols; } + resize_cols = int(scale_letterbox * img.cols); + resize_rows = int(scale_letterbox * img.rows); - int tmp_h = (lb - h0) / 2; - int tmp_w = (lb - w0) / 2; + int tmp_h = (letterbox_rows - resize_rows) / 2; + int tmp_w = (letterbox_cols - resize_cols) / 2; - float ratio_x = (float)raw_w / w0; - float ratio_y = (float)raw_h / h0; + float ratio_x = (float)img.rows / resize_rows; + float ratio_y = (float)img.cols / resize_cols; int count = picked.size(); fprintf(stderr, "detection num: %d\n",count); @@ -566,10 +554,10 @@ int main(int argc, char* argv[]) x1 = (x1 - tmp_w) * ratio_x; y1 = (y1 - tmp_h) * ratio_y; - x0 = std::max(std::min(x0, (float)(raw_w - 1)), 0.f); - y0 = std::max(std::min(y0, (float)(raw_h - 1)), 0.f); - x1 = std::max(std::min(x1, (float)(raw_w - 1)), 0.f); - y1 = std::max(std::min(y1, (float)(raw_h - 1)), 0.f); + x0 = std::max(std::min(x0, (float)(img.cols - 1)), 0.f); + y0 = std::max(std::min(y0, (float)(img.rows - 1)), 0.f); + x1 = std::max(std::min(x1, (float)(img.cols - 1)), 0.f); + y1 = std::max(std::min(y1, (float)(img.rows - 1)), 0.f); objects[i].rect.x = x0; objects[i].rect.y = y0; diff --git a/examples/tm_yolov5s_timvx.cpp b/examples/tm_yolov5s_timvx.cpp index 3b83cd71..6ed23667 100644 --- a/examples/tm_yolov5s_timvx.cpp +++ b/examples/tm_yolov5s_timvx.cpp @@ -20,14 +20,13 @@ /* * Copyright (c) 2021, OPEN AI LAB * Author: xwwang@openailab.com + * Author: stevenwudi@fiture.com */ -#include +#include #include #include #include -#include -#include #include #include #include @@ -36,6 +35,7 @@ #include "tengine_c_api.h" #include "tengine_operations.h" + struct Object { cv::Rect_ rect; @@ -133,116 +133,17 @@ static void nms_sorted_bboxes(const std::vector& faceobjects, std::vecto } } -void get_input_data_focas_uint8(const char* image_file, uint8_t* input_data, int img_h, int img_w, const float* mean, - const float* scale, float input_scale, int zero_point) -{ - cv::Mat sample = cv::imread(image_file, 1); - cv::Mat img; - - if (sample.channels() == 1) - cv::cvtColor(sample, img, cv::COLOR_GRAY2RGB); - else - cv::cvtColor(sample, img, cv::COLOR_BGR2RGB); - - /* letterbox process */ - float letterbox_size = img_h; - int resize_h = 0; - int resize_w = 0; - if (img.rows > img.cols) - { - resize_h = letterbox_size; - resize_w = int(img.cols * (letterbox_size / img.rows)); - } - else - { - resize_h = int(img.rows * (letterbox_size / img.cols)); - resize_w = letterbox_size; - } - - cv::resize(img, img, cv::Size(resize_w, resize_h)); - img.convertTo(img, CV_32FC3); - cv::Mat resize_img(letterbox_size, letterbox_size, CV_32FC3, - cv::Scalar(0.5/scale[0] + mean[0], 0.5/scale[1] + mean[1], 0.5/ scale[2] + mean[2])); - int dh = int((letterbox_size - resize_h) / 2); - int dw = int((letterbox_size - resize_w) / 2); - - for (int h = 0; h < resize_h; h++) - { - for (int w = 0; w < resize_w; w++) - { - for (int c = 0; c < 3; ++c) - { - int in_index = h * resize_w * 3 + w * 3 + c; - int out_index = (dh + h) * letterbox_size * 3 + (dw + w) * 3 + c; - - (( float* )resize_img.data)[out_index] = (( float* )img.data)[in_index]; - } - } - } - - resize_img.convertTo(resize_img, CV_32FC3); - float* img_data = (float* )resize_img.data; - float* input_temp = (float* )malloc(3 * letterbox_size * letterbox_size * sizeof(float)); - - /* nhwc to nchw */ - for (int h = 0; h < img_h; h++) - { - for (int w = 0; w < img_w; w++) - { - for (int c = 0; c < 3; c++) - { - int in_index = h * img_w * 3 + w * 3 + c; - int out_index = c * img_h * img_w + h * img_w + w; - input_temp[out_index] = (img_data[in_index] - mean[c]) * scale[c]; - } - } - } - - /* focus process */ - int input_size = letterbox_size / 2; - for (int i = 0; i < 2; i++) - { - for (int g = 0; g < 2; g++) - { - for (int c = 0; c < 3; c++) - { - for (int w = 0; w < input_size; w++) - { - for (int h = 0; h < input_size; h++) - { - int in_index = i + g * letterbox_size + c * letterbox_size * letterbox_size + w * 2 * letterbox_size + h * 2; - int out_index = i * 2 * 3 * input_size * input_size + - g * 3 * input_size * input_size + - c * input_size * input_size + - w * input_size + - h; - - /* quant to uint8 */ - int udata = (round)(input_temp[in_index] / input_scale + ( float )zero_point); - if (udata > 255) - udata = 255; - else if (udata < 0) - udata = 0; - - input_data[out_index] = udata; - } - } - } - } - } - - free(input_temp); -} -static void generate_proposals(int stride, const float* feat, float prob_threshold, std::vector& objects) +static void generate_proposals(int stride, const float* feat, float prob_threshold, std::vector& objects, + int letterbox_cols, int letterbox_rows) { static float anchors[18] = {10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326}; int anchor_num = 3; - int feat_w = 640 / stride; - int feat_h = 640 / stride; + int feat_w = letterbox_cols / stride; + int feat_h = letterbox_rows / stride; int cls_num = 80; - int anchor_group = 0; + int anchor_group; if(stride == 8) anchor_group = 1; if(stride == 16) @@ -305,15 +206,15 @@ static void generate_proposals(int stride, const float* feat, float prob_thresh static void draw_objects(const cv::Mat& bgr, const std::vector& objects) { static const char* class_names[] = { - "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", - "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", - "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", - "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", - "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", - "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", - "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", - "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", - "hair drier", "toothbrush" + "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", + "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", + "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", + "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", + "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", + "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", + "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", + "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", + "hair drier", "toothbrush" }; cv::Mat image = bgr.clone(); @@ -347,27 +248,115 @@ static void draw_objects(const cv::Mat& bgr, const std::vector& objects) cv::Scalar(0, 0, 0)); } - cv::imwrite("yolov5_timvx_out.jpg", image); + cv::imwrite("yolov5_timvx_letterbox_out.jpg", image); } void show_usage() { fprintf( - stderr, - "[Usage]: [-h]\n [-m model_file] [-i image_file] [-r repeat_count] [-t thread_count] \n"); + stderr, + "[Usage]: [-h]\n [-m model_file] [-i image_file] [-r repeat_count] [-t thread_count]\n"); +} + +void get_input_data_focus_uint8(const char* image_file, uint8_t* input_data, int letterbox_rows, int letterbox_cols, const float* mean, + const float* scale, float input_scale, int zero_point) +{ + cv::Mat sample = cv::imread(image_file, 1); + cv::Mat img; + + if (sample.channels() == 1) + cv::cvtColor(sample, img, cv::COLOR_GRAY2RGB); + else + cv::cvtColor(sample, img, cv::COLOR_BGR2RGB); + + /* letterbox process to support different letterbox size */ + float scale_letterbox; + int resize_rows; + int resize_cols; + if ((letterbox_rows * 1.0 / img.rows) < (letterbox_cols * 1.0 / img.cols)) { + scale_letterbox = letterbox_rows * 1.0 / img.rows; + } else { + scale_letterbox = letterbox_cols * 1.0 / img.cols; + } + resize_cols = int(scale_letterbox * img.cols); + resize_rows = int(scale_letterbox * img.rows); + + cv::resize(img, img, cv::Size(resize_cols, resize_rows)); + img.convertTo(img, CV_32FC3); + // Generate a gray image for letterbox using opencv + cv::Mat img_new(letterbox_cols, letterbox_rows, CV_32FC3,cv::Scalar(0.5/scale[0] + mean[0], 0.5/scale[1] + mean[1], 0.5/ scale[2] + mean[2])); + int top = (letterbox_rows - resize_rows) / 2; + int bot = (letterbox_rows - resize_rows + 1) / 2; + int left = (letterbox_cols - resize_cols) / 2; + int right = (letterbox_cols - resize_cols + 1) / 2; + // Letterbox filling + cv::copyMakeBorder(img, img_new, top, bot, left, right, cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0)); + + img_new.convertTo(img_new, CV_32FC3); + float* img_data = (float* )img_new.data; + float* input_temp = (float* )malloc(3 * letterbox_cols * letterbox_rows * sizeof(float)); + + /* nhwc to nchw */ + for (int h = 0; h < letterbox_rows; h++) + { + for (int w = 0; w < letterbox_cols; w++) + { + for (int c = 0; c < 3; c++) + { + int in_index = h * letterbox_cols * 3 + w * 3 + c; + int out_index = c * letterbox_rows * letterbox_cols + h * letterbox_cols + w; + input_temp[out_index] = (img_data[in_index] - mean[c]) * scale[c]; + } + } + } + + /* focus process */ + for (int i = 0; i < 2; i++) // corresponding to rows + { + for (int g = 0; g < 2; g++) // corresponding to cols + { + for (int c = 0; c < 3; c++) + { + for (int h = 0; h < letterbox_rows/2; h++) + { + for (int w = 0; w < letterbox_cols/2; w++) + { + int in_index = i + g * letterbox_cols + c * letterbox_cols * letterbox_rows + + h * 2 * letterbox_cols + w * 2; + int out_index = i * 2 * 3 * (letterbox_cols/2) * (letterbox_rows/2) + + g * 3 * (letterbox_cols/2) * (letterbox_rows/2) + + c * (letterbox_cols/2) * (letterbox_rows/2) + + h * (letterbox_cols/2) + + w; + + /* quant to uint8 */ + int udata = (round)(input_temp[in_index] / input_scale + ( float )zero_point); + if (udata > 255) + udata = 255; + else if (udata < 0) + udata = 0; + + input_data[out_index] = udata; + } + } + } + } + } + free(input_temp); } int main(int argc, char* argv[]) { const char* model_file = nullptr; const char* image_file = nullptr; - int letterbox_size = 640; - int img_h = letterbox_size; - int img_w = letterbox_size; int img_c = 3; const float mean[3] = {0, 0, 0}; const float scale[3] = {0.003921, 0.003921, 0.003921}; + // set default letterbox size + int letterbox_rows = 640; + int letterbox_cols = 640; + int repeat_count = 1; int num_thread = 1; @@ -383,10 +372,10 @@ int main(int argc, char* argv[]) image_file = optarg; break; case 'r': - repeat_count = std::strtoul(optarg, nullptr, 10); + repeat_count = atoi(optarg); break; case 't': - num_thread = std::strtoul(optarg, nullptr, 10); + num_thread = atoi(optarg); break; case 'h': show_usage(); @@ -454,8 +443,8 @@ int main(int argc, char* argv[]) return -1; } - int img_size = img_h * img_w * img_c; - int dims[] = {1, 12, int(img_h / 2), int(img_w / 2)}; + int img_size = letterbox_rows * letterbox_cols * img_c; + int dims[] = {1, 12, int(letterbox_rows / 2), int(letterbox_cols / 2)}; uint8_t* input_data = ( uint8_t* )malloc(img_size * sizeof(uint8_t)); tensor_t input_tensor = get_graph_input_tensor(graph, 0, 0); @@ -487,8 +476,8 @@ int main(int argc, char* argv[]) /* prepare process input data, set the data mem to input tensor */ float input_scale = 0.f; int input_zero_point = 0; - get_tensor_quant_param(input_tensor, &input_scale, &input_zero_point, 1); - get_input_data_focas_uint8(image_file, input_data, img_h, img_w, mean, scale, input_scale, input_zero_point); + get_tensor_quant_param(input_tensor, &input_scale, &input_zero_point, 1); + get_input_data_focus_uint8(image_file, input_data, letterbox_rows, letterbox_cols, mean, scale, input_scale, input_zero_point); /* run graph */ double min_time = DBL_MAX; @@ -548,7 +537,7 @@ int main(int argc, char* argv[]) for (int c = 0; c < p16_count; c++) { p16_data[c] = (( float )p16_data_u8[c] - ( float )p16_zero_point) * p16_scale; - } + } uint8_t* p32_data_u8 = ( uint8_t* )get_tensor_buffer(p32_output); float* p32_data = ( float* )malloc(sizeof(float) * p32_count); @@ -567,11 +556,11 @@ int main(int argc, char* argv[]) std::vector objects32; std::vector objects; - generate_proposals(32, p32_data, prob_threshold, objects32); + generate_proposals(32, p32_data, prob_threshold, objects32, letterbox_cols, letterbox_rows); proposals.insert(proposals.end(), objects32.begin(), objects32.end()); - generate_proposals(16, p16_data, prob_threshold, objects16); + generate_proposals(16, p16_data, prob_threshold, objects16, letterbox_cols, letterbox_rows); proposals.insert(proposals.end(), objects16.begin(), objects16.end()); - generate_proposals( 8, p8_data, prob_threshold, objects8); + generate_proposals( 8, p8_data, prob_threshold, objects8, letterbox_cols, letterbox_rows); proposals.insert(proposals.end(), objects8.begin(), objects8.end()); qsort_descent_inplace(proposals); @@ -579,28 +568,23 @@ int main(int argc, char* argv[]) nms_sorted_bboxes(proposals, picked, nms_threshold); /* yolov5 draw the result */ - int raw_h = img.rows; - int raw_w = img.cols; - float lb = letterbox_size; - int h0 = 0; - int w0 = 0; - if ( img.rows > img.cols) - { - h0 = lb; - w0 = int(img.cols * (lb / img.rows)); - } - else - { - h0 = int(img.rows * (lb / img.cols)); - w0 = lb; + float scale_letterbox; + int resize_rows; + int resize_cols; + if ((letterbox_rows * 1.0 / img.rows) < (letterbox_cols * 1.0 / img.cols)) { + scale_letterbox = letterbox_rows * 1.0 / img.rows; + } else { + scale_letterbox = letterbox_cols * 1.0 / img.cols; } + resize_cols = int(scale_letterbox * img.cols); + resize_rows = int(scale_letterbox * img.rows); - int tmp_h = (lb - h0) / 2; - int tmp_w = (lb - w0) / 2; + int tmp_h = (letterbox_rows - resize_rows) / 2; + int tmp_w = (letterbox_cols - resize_cols) / 2; - float ratio_x = (float)raw_w / w0; - float ratio_y = (float)raw_h / h0; + float ratio_x = (float)img.rows / resize_rows; + float ratio_y = (float)img.cols / resize_cols; int count = picked.size(); fprintf(stderr, "detection num: %d\n",count); @@ -619,10 +603,10 @@ int main(int argc, char* argv[]) x1 = (x1 - tmp_w) * ratio_x; y1 = (y1 - tmp_h) * ratio_y; - x0 = std::max(std::min(x0, (float)(raw_w - 1)), 0.f); - y0 = std::max(std::min(y0, (float)(raw_h - 1)), 0.f); - x1 = std::max(std::min(x1, (float)(raw_w - 1)), 0.f); - y1 = std::max(std::min(y1, (float)(raw_h - 1)), 0.f); + x0 = std::max(std::min(x0, (float)(img.cols - 1)), 0.f); + y0 = std::max(std::min(y0, (float)(img.rows - 1)), 0.f); + x1 = std::max(std::min(x1, (float)(img.cols - 1)), 0.f); + y1 = std::max(std::min(y1, (float)(img.rows - 1)), 0.f); objects[i].rect.x = x0; objects[i].rect.y = y0; -- GitLab