未验证 提交 5e19955b 编写于 作者: C cnn 提交者: GitHub

[dev] inference support bs > 1 (#3003)

* bs>1 for YOLO
上级 fd494657
...@@ -28,6 +28,8 @@ python tools/export_model.py -c configs/yolov3/yolov3_mobilenet_v1_roadsign.yml ...@@ -28,6 +28,8 @@ python tools/export_model.py -c configs/yolov3/yolov3_mobilenet_v1_roadsign.yml
* C++部署 支持`CPU``GPU``XPU`环境,支持,windows、linux系统,支持NV Jetson嵌入式设备上部署。参考文档[C++部署](cpp/README.md) * C++部署 支持`CPU``GPU``XPU`环境,支持,windows、linux系统,支持NV Jetson嵌入式设备上部署。参考文档[C++部署](cpp/README.md)
* PaddleDetection支持TensorRT加速,相关文档请参考[TensorRT预测部署教程](TENSOR_RT.md) * PaddleDetection支持TensorRT加速,相关文档请参考[TensorRT预测部署教程](TENSOR_RT.md)
**注意:** Paddle预测库版本需要>=2.1,batch_size>1仅支持YOLOv3和PP-YOLO。
## 2.PaddleServing部署 ## 2.PaddleServing部署
### 2.1 导出模型 ### 2.1 导出模型
......
...@@ -50,7 +50,7 @@ std::vector<int> GenerateColorMap(int num_class); ...@@ -50,7 +50,7 @@ std::vector<int> GenerateColorMap(int num_class);
// Visualiztion Detection Result // Visualiztion Detection Result
cv::Mat VisualizeResult(const cv::Mat& img, cv::Mat VisualizeResult(const cv::Mat& img,
const std::vector<ObjectResult>& results, const std::vector<ObjectResult>& results,
const std::vector<std::string>& lable_list, const std::vector<std::string>& lables,
const std::vector<int>& colormap, const std::vector<int>& colormap,
const bool is_rbox); const bool is_rbox);
...@@ -93,11 +93,12 @@ class ObjectDetector { ...@@ -93,11 +93,12 @@ class ObjectDetector {
const std::string& run_mode = "fluid"); const std::string& run_mode = "fluid");
// Run predictor // Run predictor
void Predict(const cv::Mat& im, void Predict(const std::vector<cv::Mat> imgs,
const double threshold = 0.5, const double threshold = 0.5,
const int warmup = 0, const int warmup = 0,
const int repeats = 1, const int repeats = 1,
std::vector<ObjectResult>* result = nullptr, std::vector<ObjectResult>* result = nullptr,
std::vector<int>* bbox_num = nullptr,
std::vector<double>* times = nullptr); std::vector<double>* times = nullptr);
// Get Model Label list // Get Model Label list
...@@ -120,14 +121,16 @@ class ObjectDetector { ...@@ -120,14 +121,16 @@ class ObjectDetector {
void Preprocess(const cv::Mat& image_mat); void Preprocess(const cv::Mat& image_mat);
// Postprocess result // Postprocess result
void Postprocess( void Postprocess(
const cv::Mat& raw_mat, const std::vector<cv::Mat> mats,
std::vector<ObjectResult>* result, std::vector<ObjectResult>* result,
std::vector<int> bbox_num,
bool is_rbox); bool is_rbox);
std::shared_ptr<Predictor> predictor_; std::shared_ptr<Predictor> predictor_;
Preprocessor preprocessor_; Preprocessor preprocessor_;
ImageBlob inputs_; ImageBlob inputs_;
std::vector<float> output_data_; std::vector<float> output_data_;
std::vector<int> out_bbox_num_data_;
float threshold_; float threshold_;
ConfigPaser config_; ConfigPaser config_;
std::vector<int> image_shape_; std::vector<int> image_shape_;
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <numeric> #include <numeric>
#include <sys/types.h> #include <sys/types.h>
#include <sys/stat.h> #include <sys/stat.h>
#include <math.h>
#ifdef _WIN32 #ifdef _WIN32
#include <direct.h> #include <direct.h>
...@@ -37,6 +38,7 @@ ...@@ -37,6 +38,7 @@
DEFINE_string(model_dir, "", "Path of inference model"); DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_string(image_file, "", "Path of input image"); DEFINE_string(image_file, "", "Path of input image");
DEFINE_string(image_dir, "", "Dir of input image, `image_file` has a higher priority."); DEFINE_string(image_dir, "", "Dir of input image, `image_file` has a higher priority.");
DEFINE_int32(batch_size, 1, "batch_size");
DEFINE_string(video_file, "", "Path of input video, `video_file` or `camera_id` has a highest priority."); DEFINE_string(video_file, "", "Path of input video, `video_file` or `camera_id` has a highest priority.");
DEFINE_int32(camera_id, -1, "Device id of camera to predict"); DEFINE_int32(camera_id, -1, "Device id of camera to predict");
DEFINE_bool(use_gpu, false, "Infering with GPU or CPU"); DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
...@@ -189,6 +191,7 @@ void PredictVideo(const std::string& video_path, ...@@ -189,6 +191,7 @@ void PredictVideo(const std::string& video_path,
} }
std::vector<PaddleDetection::ObjectResult> result; std::vector<PaddleDetection::ObjectResult> result;
std::vector<int> bbox_num;
std::vector<double> det_times; std::vector<double> det_times;
auto labels = det->GetLabelList(); auto labels = det->GetLabelList();
auto colormap = PaddleDetection::GenerateColorMap(labels.size()); auto colormap = PaddleDetection::GenerateColorMap(labels.size());
...@@ -200,8 +203,9 @@ void PredictVideo(const std::string& video_path, ...@@ -200,8 +203,9 @@ void PredictVideo(const std::string& video_path,
if (frame.empty()) { if (frame.empty()) {
break; break;
} }
std::vector<cv::Mat> imgs;
det->Predict(frame, 0.5, 0, 1, &result, &det_times); imgs.push_back(frame);
det->Predict(imgs, 0.5, 0, 1, &result, &bbox_num, &det_times);
for (const auto& item : result) { for (const auto& item : result) {
if (item.rect.size() > 6){ if (item.rect.size() > 6){
is_rbox = true; is_rbox = true;
...@@ -238,70 +242,107 @@ void PredictVideo(const std::string& video_path, ...@@ -238,70 +242,107 @@ void PredictVideo(const std::string& video_path,
video_out.release(); video_out.release();
} }
void PredictImage(const std::vector<std::string> all_img_list, void PredictImage(const std::vector<std::string> all_img_paths,
const int batch_size,
const double threshold, const double threshold,
const bool run_benchmark, const bool run_benchmark,
PaddleDetection::ObjectDetector* det, PaddleDetection::ObjectDetector* det,
const std::string& output_dir = "output") { const std::string& output_dir = "output") {
std::vector<double> det_t = {0, 0, 0}; std::vector<double> det_t = {0, 0, 0};
for (auto image_file : all_img_list) { int steps = ceil(float(all_img_paths.size()) / batch_size);
// Open input image as an opencv cv::Mat object printf("total images = %d, batch_size = %d, total steps = %d\n",
cv::Mat im = cv::imread(image_file, 1); all_img_paths.size(), batch_size, steps);
for (int idx = 0; idx < steps; idx++) {
std::vector<cv::Mat> batch_imgs;
int left_image_cnt = all_img_paths.size() - idx * batch_size;
if (left_image_cnt > batch_size) {
left_image_cnt = batch_size;
}
for (int bs = 0; bs < left_image_cnt; bs++) {
std::string image_file_path = all_img_paths.at(idx * batch_size+bs);
cv::Mat im = cv::imread(image_file_path, 1);
batch_imgs.insert(batch_imgs.end(), im);
}
// Store all detected result // Store all detected result
std::vector<PaddleDetection::ObjectResult> result; std::vector<PaddleDetection::ObjectResult> result;
std::vector<int> bbox_num;
std::vector<double> det_times; std::vector<double> det_times;
bool is_rbox = false; bool is_rbox = false;
if (run_benchmark) { if (run_benchmark) {
det->Predict(im, threshold, 10, 10, &result, &det_times); det->Predict(batch_imgs, threshold, 10, 10, &result, &bbox_num, &det_times);
} else { } else {
det->Predict(im, 0.5, 0, 1, &result, &det_times); det->Predict(batch_imgs, 0.5, 0, 1, &result, &bbox_num, &det_times);
for (const auto& item : result) { // get labels and colormap
if (item.rect.size() > 6){ auto labels = det->GetLabelList();
is_rbox = true; auto colormap = PaddleDetection::GenerateColorMap(labels.size());
printf("class=%d confidence=%.4f rect=[%d %d %d %d %d %d %d %d]\n",
item.class_id, int item_start_idx = 0;
item.confidence, for (int i = 0; i < left_image_cnt; i++) {
item.rect[0], std::cout << all_img_paths.at(idx * batch_size + i) << "result" << std::endl;
item.rect[1], if (bbox_num[i] <= 1) {
item.rect[2], continue;
item.rect[3],
item.rect[4],
item.rect[5],
item.rect[6],
item.rect[7]);
} }
else{ for (int j = 0; j < bbox_num[i]; j++) {
printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n", PaddleDetection::ObjectResult item = result[item_start_idx + j];
item.class_id, if (item.rect.size() > 6){
item.confidence, is_rbox = true;
item.rect[0], printf("class=%d confidence=%.4f rect=[%d %d %d %d %d %d %d %d]\n",
item.rect[1], item.class_id,
item.rect[2], item.confidence,
item.rect[3]); item.rect[0],
item.rect[1],
item.rect[2],
item.rect[3],
item.rect[4],
item.rect[5],
item.rect[6],
item.rect[7]);
}
else{
printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n",
item.class_id,
item.confidence,
item.rect[0],
item.rect[1],
item.rect[2],
item.rect[3]);
}
} }
item_start_idx = item_start_idx + bbox_num[i];
} }
// Visualization result // Visualization result
auto labels = det->GetLabelList(); int bbox_idx = 0;
auto colormap = PaddleDetection::GenerateColorMap(labels.size()); for (int bs = 0; bs < batch_imgs.size(); bs++) {
cv::Mat vis_img = PaddleDetection::VisualizeResult( if (bbox_num[bs] <= 1) {
im, result, labels, colormap, is_rbox); continue;
std::vector<int> compression_params; }
compression_params.push_back(CV_IMWRITE_JPEG_QUALITY); cv::Mat im = batch_imgs[bs];
compression_params.push_back(95); std::vector<PaddleDetection::ObjectResult> im_result;
std::string output_path(output_dir); for (int k = 0; k < bbox_num[bs]; k++) {
if (output_dir.rfind(OS_PATH_SEP) != output_dir.size() - 1) { im_result.push_back(result[bbox_idx+k]);
output_path += OS_PATH_SEP; }
bbox_idx += bbox_num[bs];
cv::Mat vis_img = PaddleDetection::VisualizeResult(
im, im_result, labels, colormap, is_rbox);
std::vector<int> compression_params;
compression_params.push_back(CV_IMWRITE_JPEG_QUALITY);
compression_params.push_back(95);
std::string output_path(output_dir);
if (output_dir.rfind(OS_PATH_SEP) != output_dir.size() - 1) {
output_path += OS_PATH_SEP;
}
std::string image_file_path = all_img_paths.at(idx * batch_size+bs);
output_path += image_file_path.substr(image_file_path.find_last_of('/') + 1);
cv::imwrite(output_path, vis_img, compression_params);
printf("Visualized output saved as %s\n", output_path.c_str());
} }
;
output_path += image_file.substr(image_file.find_last_of('/') + 1);
cv::imwrite(output_path, vis_img, compression_params);
printf("Visualized output saved as %s\n", output_path.c_str());
} }
det_t[0] += det_times[0]; det_t[0] += det_times[0];
det_t[1] += det_times[1]; det_t[1] += det_times[1];
det_t[2] += det_times[2]; det_t[2] += det_times[2];
} }
PrintBenchmarkLog(det_t, all_img_list.size()); PrintBenchmarkLog(det_t, all_img_paths.size());
} }
int main(int argc, char** argv) { int main(int argc, char** argv) {
...@@ -329,13 +370,17 @@ int main(int argc, char** argv) { ...@@ -329,13 +370,17 @@ int main(int argc, char** argv) {
if (!PathExists(FLAGS_output_dir)) { if (!PathExists(FLAGS_output_dir)) {
MkDirs(FLAGS_output_dir); MkDirs(FLAGS_output_dir);
} }
std::vector<std::string> all_img_list; std::vector<std::string> all_imgs;
if (!FLAGS_image_file.empty()) { if (!FLAGS_image_file.empty()) {
all_img_list.push_back(FLAGS_image_file); all_imgs.push_back(FLAGS_image_file);
if (FLAGS_batch_size > 1) {
std::cout << "batch_size should be 1, when image_file is not None" << std::endl;
FLAGS_batch_size = 1;
}
} else { } else {
GetAllFiles((char *)FLAGS_image_dir.c_str(), all_img_list); GetAllFiles((char *)FLAGS_image_dir.c_str(), all_imgs);
} }
PredictImage(all_img_list, FLAGS_threshold, FLAGS_run_benchmark, &det, FLAGS_output_dir); PredictImage(all_imgs, FLAGS_batch_size, FLAGS_threshold, FLAGS_run_benchmark, &det, FLAGS_output_dir);
} }
return 0; return 0;
} }
...@@ -93,7 +93,7 @@ void ObjectDetector::LoadModel(const std::string& model_dir, ...@@ -93,7 +93,7 @@ void ObjectDetector::LoadModel(const std::string& model_dir,
// Visualiztion MaskDetector results // Visualiztion MaskDetector results
cv::Mat VisualizeResult(const cv::Mat& img, cv::Mat VisualizeResult(const cv::Mat& img,
const std::vector<ObjectResult>& results, const std::vector<ObjectResult>& results,
const std::vector<std::string>& lable_list, const std::vector<std::string>& lables,
const std::vector<int>& colormap, const std::vector<int>& colormap,
const bool is_rbox=false) { const bool is_rbox=false) {
cv::Mat vis_img = img.clone(); cv::Mat vis_img = img.clone();
...@@ -101,7 +101,7 @@ cv::Mat VisualizeResult(const cv::Mat& img, ...@@ -101,7 +101,7 @@ cv::Mat VisualizeResult(const cv::Mat& img,
// Configure color and text size // Configure color and text size
std::ostringstream oss; std::ostringstream oss;
oss << std::setiosflags(std::ios::fixed) << std::setprecision(4); oss << std::setiosflags(std::ios::fixed) << std::setprecision(4);
oss << lable_list[results[i].class_id] << " "; oss << lables[results[i].class_id] << " ";
oss << results[i].confidence; oss << results[i].confidence;
std::string text = oss.str(); std::string text = oss.str();
int c1 = colormap[3 * results[i].class_id + 0]; int c1 = colormap[3 * results[i].class_id + 0];
...@@ -121,20 +121,20 @@ cv::Mat VisualizeResult(const cv::Mat& img, ...@@ -121,20 +121,20 @@ cv::Mat VisualizeResult(const cv::Mat& img,
if (is_rbox) if (is_rbox)
{ {
// Draw object, text, and background // Draw object, text, and background
for (int k=0; k<4; k++) for (int k = 0; k < 4; k++)
{ {
cv::Point pt1 = cv::Point(results[i].rect[(k*2)%8], cv::Point pt1 = cv::Point(results[i].rect[(k * 2) % 8],
results[i].rect[(k*2+1)%8]); results[i].rect[(k * 2 + 1) % 8]);
cv::Point pt2 = cv::Point(results[i].rect[(k*2+2)%8], cv::Point pt2 = cv::Point(results[i].rect[(k * 2 + 2) % 8],
results[i].rect[(k*2+3)%8]); results[i].rect[(k * 2 + 3) % 8]);
cv::line(vis_img, pt1, pt2, roi_color, 2); cv::line(vis_img, pt1, pt2, roi_color, 2);
} }
} }
else else
{ {
int w = results[i].rect[1] - results[i].rect[0]; int w = results[i].rect[2] - results[i].rect[0];
int h = results[i].rect[3] - results[i].rect[2]; int h = results[i].rect[3] - results[i].rect[1];
cv::Rect roi = cv::Rect(results[i].rect[0], results[i].rect[2], w, h); cv::Rect roi = cv::Rect(results[i].rect[0], results[i].rect[1], w, h);
// Draw roi object, text, and background // Draw roi object, text, and background
cv::rectangle(vis_img, roi, roi_color, 2); cv::rectangle(vis_img, roi, roi_color, 2);
} }
...@@ -144,7 +144,7 @@ cv::Mat VisualizeResult(const cv::Mat& img, ...@@ -144,7 +144,7 @@ cv::Mat VisualizeResult(const cv::Mat& img,
// Configure text background // Configure text background
cv::Rect text_back = cv::Rect(results[i].rect[0], cv::Rect text_back = cv::Rect(results[i].rect[0],
results[i].rect[2] - text_size.height, results[i].rect[1] - text_size.height,
text_size.width, text_size.width,
text_size.height); text_size.height);
// Draw text, and background // Draw text, and background
...@@ -168,76 +168,100 @@ void ObjectDetector::Preprocess(const cv::Mat& ori_im) { ...@@ -168,76 +168,100 @@ void ObjectDetector::Preprocess(const cv::Mat& ori_im) {
} }
void ObjectDetector::Postprocess( void ObjectDetector::Postprocess(
const cv::Mat& raw_mat, const std::vector<cv::Mat> mats,
std::vector<ObjectResult>* result, std::vector<ObjectResult>* result,
std::vector<int> bbox_num,
bool is_rbox=false) { bool is_rbox=false) {
result->clear(); result->clear();
int rh = 1; int start_idx = 0;
int rw = 1; for (int im_id = 0; im_id < bbox_num.size(); im_id++) {
if (config_.arch_ == "Face") { cv::Mat raw_mat = mats[im_id];
rh = raw_mat.rows; for (int j = start_idx; j < start_idx+bbox_num[im_id]; j++) {
rw = raw_mat.cols; int rh = 1;
} int rw = 1;
if (config_.arch_ == "Face") {
rh = raw_mat.rows;
rw = raw_mat.cols;
}
if (is_rbox) if (is_rbox) {
{ for (int j = 0; j < bbox_num[im_id]; ++j) {
int total_size = output_data_.size() / 10; // Class id
for (int j = 0; j < total_size; ++j) { int class_id = static_cast<int>(round(output_data_[0 + j * 10]));
// Class id // Confidence score
int class_id = static_cast<int>(round(output_data_[0 + j * 10])); float score = output_data_[1 + j * 10];
// Confidence score int x1 = (output_data_[2 + j * 10] * rw);
float score = output_data_[1 + j * 10]; int y1 = (output_data_[3 + j * 10] * rh);
int x1 = (output_data_[2 + j * 10] * rw); int x2 = (output_data_[4 + j * 10] * rw);
int y1 = (output_data_[3 + j * 10] * rh); int y2 = (output_data_[5 + j * 10] * rh);
int x2 = (output_data_[4 + j * 10] * rw); int x3 = (output_data_[6 + j * 10] * rw);
int y2 = (output_data_[5 + j * 10] * rh); int y3 = (output_data_[7 + j * 10] * rh);
int x3 = (output_data_[6 + j * 10] * rw); int x4 = (output_data_[8 + j * 10] * rw);
int y3 = (output_data_[7 + j * 10] * rh); int y4 = (output_data_[9 + j * 10] * rh);
int x4 = (output_data_[8 + j * 10] * rw); if (score > threshold_ && class_id > -1) {
int y4 = (output_data_[9 + j * 10] * rh); ObjectResult result_item;
if (score > threshold_ && class_id > -1) { result_item.rect = {x1, y1, x2, y2, x3, y3, x4, y4};
ObjectResult result_item; result_item.class_id = class_id;
result_item.rect = {x1, y1, x2, y2, x3, y3, x4, y4}; result_item.confidence = score;
result_item.class_id = class_id; result->push_back(result_item);
result_item.confidence = score; }
result->push_back(result_item); }
} }
} else {
} for (int j = 0; j < bbox_num[im_id]; ++j) {
else // Class id
{ int class_id = static_cast<int>(round(output_data_[0 + j * 6]));
int total_size = output_data_.size() / 6; // Confidence score
for (int j = 0; j < total_size; ++j) { float score = output_data_[1 + j * 6];
// Class id int xmin = (output_data_[2 + j * 6] * rw);
int class_id = static_cast<int>(round(output_data_[0 + j * 6])); int ymin = (output_data_[3 + j * 6] * rh);
// Confidence score int xmax = (output_data_[4 + j * 6] * rw);
float score = output_data_[1 + j * 6]; int ymax = (output_data_[5 + j * 6] * rh);
int xmin = (output_data_[2 + j * 6] * rw); int wd = xmax - xmin;
int ymin = (output_data_[3 + j * 6] * rh); int hd = ymax - ymin;
int xmax = (output_data_[4 + j * 6] * rw); if (score > threshold_ && class_id > -1) {
int ymax = (output_data_[5 + j * 6] * rh); ObjectResult result_item;
int wd = xmax - xmin; result_item.rect = {xmin, ymin, xmax, ymax};
int hd = ymax - ymin; result_item.class_id = class_id;
if (score > threshold_ && class_id > -1) { result_item.confidence = score;
ObjectResult result_item; result->push_back(result_item);
result_item.rect = {xmin, xmax, ymin, ymax}; }
result_item.class_id = class_id; }
result_item.confidence = score;
result->push_back(result_item);
} }
} }
start_idx += bbox_num[im_id];
} }
} }
void ObjectDetector::Predict(const cv::Mat& im, void ObjectDetector::Predict(const std::vector<cv::Mat> imgs,
const double threshold, const double threshold,
const int warmup, const int warmup,
const int repeats, const int repeats,
std::vector<ObjectResult>* result, std::vector<ObjectResult>* result,
std::vector<int>* bbox_num,
std::vector<double>* times) { std::vector<double>* times) {
auto preprocess_start = std::chrono::steady_clock::now(); auto preprocess_start = std::chrono::steady_clock::now();
int batch_size = imgs.size();
// in_data_batch
std::vector<float> in_data_all;
std::vector<float> im_shape_all(batch_size * 2);
std::vector<float> scale_factor_all(batch_size * 2);
// Preprocess image // Preprocess image
Preprocess(im); for (int bs_idx = 0; bs_idx < batch_size; bs_idx++) {
cv::Mat im = imgs.at(bs_idx);
Preprocess(im);
im_shape_all[bs_idx * 2] = inputs_.im_shape_[0];
im_shape_all[bs_idx * 2 + 1] = inputs_.im_shape_[1];
scale_factor_all[bs_idx * 2] = inputs_.scale_factor_[0];
scale_factor_all[bs_idx * 2 + 1] = inputs_.scale_factor_[1];
// TODO: reduce cost time
in_data_all.insert(in_data_all.end(), inputs_.im_data_.begin(), inputs_.im_data_.end());
}
// Prepare input tensor // Prepare input tensor
auto input_names = predictor_->GetInputNames(); auto input_names = predictor_->GetInputNames();
for (const auto& tensor_name : input_names) { for (const auto& tensor_name : input_names) {
...@@ -245,14 +269,14 @@ void ObjectDetector::Predict(const cv::Mat& im, ...@@ -245,14 +269,14 @@ void ObjectDetector::Predict(const cv::Mat& im,
if (tensor_name == "image") { if (tensor_name == "image") {
int rh = inputs_.in_net_shape_[0]; int rh = inputs_.in_net_shape_[0];
int rw = inputs_.in_net_shape_[1]; int rw = inputs_.in_net_shape_[1];
in_tensor->Reshape({1, 3, rh, rw}); in_tensor->Reshape({batch_size, 3, rh, rw});
in_tensor->CopyFromCpu(inputs_.im_data_.data()); in_tensor->CopyFromCpu(in_data_all.data());
} else if (tensor_name == "im_shape") { } else if (tensor_name == "im_shape") {
in_tensor->Reshape({1, 2}); in_tensor->Reshape({batch_size, 2});
in_tensor->CopyFromCpu(inputs_.im_shape_.data()); in_tensor->CopyFromCpu(im_shape_all.data());
} else if (tensor_name == "scale_factor") { } else if (tensor_name == "scale_factor") {
in_tensor->Reshape({1, 2}); in_tensor->Reshape({batch_size, 2});
in_tensor->CopyFromCpu(inputs_.scale_factor_.data()); in_tensor->CopyFromCpu(scale_factor_all.data());
} }
} }
auto preprocess_end = std::chrono::steady_clock::now(); auto preprocess_end = std::chrono::steady_clock::now();
...@@ -266,10 +290,6 @@ void ObjectDetector::Predict(const cv::Mat& im, ...@@ -266,10 +290,6 @@ void ObjectDetector::Predict(const cv::Mat& im,
std::vector<int> output_shape = out_tensor->shape(); std::vector<int> output_shape = out_tensor->shape();
// Calculate output length // Calculate output length
int output_size = 1; int output_size = 1;
for (int j = 0; j < output_shape.size(); ++j) {
output_size *= output_shape[j];
}
if (output_size < 6) { if (output_size < 6) {
std::cerr << "[WARNING] No object detected." << std::endl; std::cerr << "[WARNING] No object detected." << std::endl;
} }
...@@ -286,6 +306,8 @@ void ObjectDetector::Predict(const cv::Mat& im, ...@@ -286,6 +306,8 @@ void ObjectDetector::Predict(const cv::Mat& im,
auto output_names = predictor_->GetOutputNames(); auto output_names = predictor_->GetOutputNames();
auto out_tensor = predictor_->GetOutputHandle(output_names[0]); auto out_tensor = predictor_->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = out_tensor->shape(); std::vector<int> output_shape = out_tensor->shape();
auto out_bbox_num = predictor_->GetOutputHandle(output_names[1]);
std::vector<int> out_bbox_num_shape = out_bbox_num->shape();
// Calculate output length // Calculate output length
int output_size = 1; int output_size = 1;
for (int j = 0; j < output_shape.size(); ++j) { for (int j = 0; j < output_shape.size(); ++j) {
...@@ -298,11 +320,23 @@ void ObjectDetector::Predict(const cv::Mat& im, ...@@ -298,11 +320,23 @@ void ObjectDetector::Predict(const cv::Mat& im,
} }
output_data_.resize(output_size); output_data_.resize(output_size);
out_tensor->CopyToCpu(output_data_.data()); out_tensor->CopyToCpu(output_data_.data());
int out_bbox_num_size = 1;
for (int j = 0; j < out_bbox_num_shape.size(); ++j) {
out_bbox_num_size *= out_bbox_num_shape[j];
}
out_bbox_num_data_.resize(out_bbox_num_size);
out_bbox_num->CopyToCpu(out_bbox_num_data_.data());
} }
auto inference_end = std::chrono::steady_clock::now(); auto inference_end = std::chrono::steady_clock::now();
auto postprocess_start = std::chrono::steady_clock::now(); auto postprocess_start = std::chrono::steady_clock::now();
// Postprocessing result // Postprocessing result
Postprocess(im, result, is_rbox); Postprocess(imgs, result, out_bbox_num_data_, is_rbox);
bbox_num->clear();
for (int k=0; k<out_bbox_num_data_.size(); k++) {
int tmp = out_bbox_num_data_[k];
bbox_num->push_back(tmp);
}
auto postprocess_end = std::chrono::steady_clock::now(); auto postprocess_end = std::chrono::steady_clock::now();
std::chrono::duration<float> preprocess_diff = preprocess_end - preprocess_start; std::chrono::duration<float> preprocess_diff = preprocess_end - preprocess_start;
......
...@@ -129,7 +129,6 @@ void PadStride::Run(cv::Mat* im, ImageBlob* data) { ...@@ -129,7 +129,6 @@ void PadStride::Run(cv::Mat* im, ImageBlob* data) {
static_cast<float>(im->rows), static_cast<float>(im->rows),
static_cast<float>(im->cols), static_cast<float>(im->cols),
}; };
} }
......
...@@ -21,6 +21,7 @@ from functools import reduce ...@@ -21,6 +21,7 @@ from functools import reduce
from PIL import Image from PIL import Image
import cv2 import cv2
import numpy as np import numpy as np
import math
import paddle import paddle
from paddle.inference import Config from paddle.inference import Config
from paddle.inference import create_predictor from paddle.inference import create_predictor
...@@ -85,18 +86,29 @@ class Detector(object): ...@@ -85,18 +86,29 @@ class Detector(object):
self.det_times = Timer() self.det_times = Timer()
self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0 self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0
def preprocess(self, im): def preprocess(self, image_list):
preprocess_ops = [] preprocess_ops = []
for op_info in self.pred_config.preprocess_infos: for op_info in self.pred_config.preprocess_infos:
new_op_info = op_info.copy() new_op_info = op_info.copy()
op_type = new_op_info.pop('type') op_type = new_op_info.pop('type')
preprocess_ops.append(eval(op_type)(**new_op_info)) preprocess_ops.append(eval(op_type)(**new_op_info))
im, im_info = preprocess(im, preprocess_ops,
self.pred_config.input_shape) input_im_lst = []
inputs = create_inputs(im, im_info) input_im_info_lst = []
for im_path in image_list:
im, im_info = preprocess(im_path, preprocess_ops,
self.pred_config.input_shape)
input_im_lst.append(im)
input_im_info_lst.append(im_info)
inputs = create_inputs(input_im_lst, input_im_info_lst)
return inputs return inputs
def postprocess(self, np_boxes, np_masks, inputs, threshold=0.5): def postprocess(self,
np_boxes,
np_masks,
inputs,
np_boxes_num,
threshold=0.5):
# postprocess output of predictor # postprocess output of predictor
results = {} results = {}
if self.pred_config.arch in ['Face']: if self.pred_config.arch in ['Face']:
...@@ -108,14 +120,15 @@ class Detector(object): ...@@ -108,14 +120,15 @@ class Detector(object):
np_boxes[:, 4] *= h np_boxes[:, 4] *= h
np_boxes[:, 5] *= w np_boxes[:, 5] *= w
results['boxes'] = np_boxes results['boxes'] = np_boxes
results['boxes_num'] = np_boxes_num
if np_masks is not None: if np_masks is not None:
results['masks'] = np_masks results['masks'] = np_masks
return results return results
def predict(self, image, threshold=0.5, warmup=0, repeats=1): def predict(self, image_list, threshold=0.5, warmup=0, repeats=1):
''' '''
Args: Args:
image (str/np.ndarray): path of image/ np.ndarray read by cv2 image_list (list): ,list of image
threshold (float): threshold of predicted box' score threshold (float): threshold of predicted box' score
Returns: Returns:
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box, results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
...@@ -124,7 +137,7 @@ class Detector(object): ...@@ -124,7 +137,7 @@ class Detector(object):
shape: [N, im_h, im_w] shape: [N, im_h, im_w]
''' '''
self.det_times.preprocess_time_s.start() self.det_times.preprocess_time_s.start()
inputs = self.preprocess(image) inputs = self.preprocess(image_list)
np_boxes, np_masks = None, None np_boxes, np_masks = None, None
input_names = self.predictor.get_input_names() input_names = self.predictor.get_input_names()
for i in range(len(input_names)): for i in range(len(input_names)):
...@@ -146,6 +159,8 @@ class Detector(object): ...@@ -146,6 +159,8 @@ class Detector(object):
output_names = self.predictor.get_output_names() output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_handle(output_names[0]) boxes_tensor = self.predictor.get_output_handle(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu() np_boxes = boxes_tensor.copy_to_cpu()
boxes_num = self.predictor.get_output_handle(output_names[1])
np_boxes_num = boxes_num.copy_to_cpu()
if self.pred_config.mask: if self.pred_config.mask:
masks_tensor = self.predictor.get_output_handle(output_names[2]) masks_tensor = self.predictor.get_output_handle(output_names[2])
np_masks = masks_tensor.copy_to_cpu() np_masks = masks_tensor.copy_to_cpu()
...@@ -155,12 +170,12 @@ class Detector(object): ...@@ -155,12 +170,12 @@ class Detector(object):
results = [] results = []
if reduce(lambda x, y: x * y, np_boxes.shape) < 6: if reduce(lambda x, y: x * y, np_boxes.shape) < 6:
print('[WARNNING] No object detected.') print('[WARNNING] No object detected.')
results = {'boxes': np.array([])} results = {'boxes': np.array([]), 'boxes_num': [0]}
else: else:
results = self.postprocess( results = self.postprocess(
np_boxes, np_masks, inputs, threshold=threshold) np_boxes, np_masks, inputs, np_boxes_num, threshold=threshold)
self.det_times.postprocess_time_s.end() self.det_times.postprocess_time_s.end()
self.det_times.img_num += 1 self.det_times.img_num += len(image_list)
return results return results
...@@ -249,21 +264,45 @@ class DetectorSOLOv2(Detector): ...@@ -249,21 +264,45 @@ class DetectorSOLOv2(Detector):
return dict(segm=np_segms, label=np_label, score=np_score) return dict(segm=np_segms, label=np_label, score=np_score)
def create_inputs(im, im_info): def create_inputs(imgs, im_info):
"""generate input for different model type """generate input for different model type
Args: Args:
im (np.ndarray): image (np.ndarray) im (np.ndarray): image (np.ndarray)
im_info (dict): info of image im_info (dict): info of image
model_arch (str): model type
Returns: Returns:
inputs (dict): input of model inputs (dict): input of model
""" """
inputs = {} inputs = {}
inputs['image'] = np.array((im, )).astype('float32')
inputs['im_shape'] = np.array((im_info['im_shape'], )).astype('float32')
inputs['scale_factor'] = np.array(
(im_info['scale_factor'], )).astype('float32')
im_shape = []
scale_factor = []
for e in im_info:
im_shape.append(np.array((e['im_shape'], )).astype('float32'))
scale_factor.append(np.array((e['scale_factor'], )).astype('float32'))
origin_scale_factor = np.concatenate(scale_factor, axis=0)
imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs]
max_shape_h = max([e[0] for e in imgs_shape])
max_shape_w = max([e[1] for e in imgs_shape])
padding_imgs = []
padding_imgs_shape = []
padding_imgs_scale = []
for img in imgs:
im_c, im_h, im_w = img.shape[:]
padding_im = np.zeros(
(im_c, max_shape_h, max_shape_w), dtype=np.float32)
padding_im[:, :im_h, :im_w] = img
padding_imgs.append(padding_im)
padding_imgs_shape.append(
np.array([max_shape_h, max_shape_w]).astype('float32'))
rescale = [
float(max_shape_h) / float(im_h), float(max_shape_w) / float(im_w)
]
padding_imgs_scale.append(np.array(rescale).astype('float32'))
inputs['image'] = np.stack(padding_imgs, axis=0)
inputs['im_shape'] = np.stack(padding_imgs_shape, axis=0)
inputs['scale_factor'] = origin_scale_factor
return inputs return inputs
...@@ -426,15 +465,30 @@ def get_test_images(infer_dir, infer_img): ...@@ -426,15 +465,30 @@ def get_test_images(infer_dir, infer_img):
return images return images
def visualize(image_file, results, labels, output_dir='output/', threshold=0.5): def visualize(image_list, results, labels, output_dir='output/', threshold=0.5):
# visualize the predict result # visualize the predict result
im = visualize_box_mask(image_file, results, labels, threshold=threshold) start_idx = 0
img_name = os.path.split(image_file)[-1] for idx, image_file in enumerate(image_list):
if not os.path.exists(output_dir): im_bboxes_num = results['boxes_num'][idx]
os.makedirs(output_dir) im_results = {}
out_path = os.path.join(output_dir, img_name) if 'boxes' in results:
im.save(out_path, quality=95) im_results['boxes'] = results['boxes'][start_idx:start_idx +
print("save result to: " + out_path) im_bboxes_num, :]
if 'masks' in results:
im_results['masks'] = results['masks'][start_idx:start_idx +
im_bboxes_num, :]
if 'segm' in results:
im_results['segm'] = results['segm'][start_idx:start_idx +
im_bboxes_num, :]
start_idx += im_bboxes_num
im = visualize_box_mask(
image_file, im_results, labels, threshold=threshold)
img_name = os.path.split(image_file)[-1]
if not os.path.exists(output_dir):
os.makedirs(output_dir)
out_path = os.path.join(output_dir, img_name)
im.save(out_path, quality=95)
print("save result to: " + out_path)
def print_arguments(args): def print_arguments(args):
...@@ -444,19 +498,24 @@ def print_arguments(args): ...@@ -444,19 +498,24 @@ def print_arguments(args):
print('------------------------------------------') print('------------------------------------------')
def predict_image(detector, image_list): def predict_image(detector, image_list, batch_size=1):
for i, img_file in enumerate(image_list): batch_loop_cnt = math.ceil(float(len(image_list)) / batch_size)
for i in range(batch_loop_cnt):
start_index = i * batch_size
end_index = min((i + 1) * batch_size, len(image_list))
batch_image_list = image_list[start_index:end_index]
if FLAGS.run_benchmark: if FLAGS.run_benchmark:
detector.predict(img_file, FLAGS.threshold, warmup=10, repeats=10) detector.predict(
batch_image_list, FLAGS.threshold, warmup=10, repeats=10)
cm, gm, gu = get_current_memory_mb() cm, gm, gu = get_current_memory_mb()
detector.cpu_mem += cm detector.cpu_mem += cm
detector.gpu_mem += gm detector.gpu_mem += gm
detector.gpu_util += gu detector.gpu_util += gu
print('Test iter {}, file name:{}'.format(i, img_file)) print('Test iter {}'.format(i))
else: else:
results = detector.predict(img_file, FLAGS.threshold) results = detector.predict(batch_image_list, FLAGS.threshold)
visualize( visualize(
img_file, batch_image_list,
results, results,
detector.pred_config.labels, detector.pred_config.labels,
output_dir=FLAGS.output_dir, output_dir=FLAGS.output_dir,
...@@ -535,8 +594,10 @@ def main(): ...@@ -535,8 +594,10 @@ def main():
predict_video(detector, FLAGS.camera_id) predict_video(detector, FLAGS.camera_id)
else: else:
# predict from image # predict from image
if FLAGS.image_dir is None and FLAGS.image_file is not None:
assert FLAGS.batch_size == 1, "batch_size should be 1, when image_file is not None"
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file) img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
predict_image(detector, img_list) predict_image(detector, img_list, FLAGS.batch_size)
if not FLAGS.run_benchmark: if not FLAGS.run_benchmark:
detector.det_times.info(average=True) detector.det_times.info(average=True)
else: else:
......
...@@ -34,6 +34,8 @@ def argsparser(): ...@@ -34,6 +34,8 @@ def argsparser():
type=str, type=str,
default=None, default=None,
help="Dir of image file, `image_file` has a higher priority.") help="Dir of image file, `image_file` has a higher priority.")
parser.add_argument(
"--batch_size", type=int, default=1, help="batch_size for infer.")
parser.add_argument( parser.add_argument(
"--video_file", "--video_file",
type=str, type=str,
......
...@@ -436,7 +436,7 @@ class Trainer(object): ...@@ -436,7 +436,7 @@ class Trainer(object):
image = visualize_results( image = visualize_results(
image, bbox_res, mask_res, segm_res, keypoint_res, image, bbox_res, mask_res, segm_res, keypoint_res,
int(outs['im_id']), catid2name, draw_threshold) int(im_id), catid2name, draw_threshold)
self.status['result_image'] = np.array(image.copy()) self.status['result_image'] = np.array(image.copy())
if self._compose_callback: if self._compose_callback:
self._compose_callback.on_step_end(self.status) self._compose_callback.on_step_end(self.status)
......
...@@ -83,11 +83,13 @@ class S2ANet(BaseArch): ...@@ -83,11 +83,13 @@ class S2ANet(BaseArch):
nms_pre = self.s2anet_bbox_post_process.nms_pre nms_pre = self.s2anet_bbox_post_process.nms_pre
pred_scores, pred_bboxes = self.s2anet_head.get_prediction(nms_pre) pred_scores, pred_bboxes = self.s2anet_head.get_prediction(nms_pre)
# post_process
pred_bboxes, bbox_num = self.s2anet_bbox_post_process(pred_scores, pred_bboxes, bbox_num = self.s2anet_bbox_post_process(pred_scores,
pred_bboxes) pred_bboxes)
# rescale the prediction back to origin image # rescale the prediction back to origin image
pred_bboxes = self.s2anet_bbox_post_process.get_pred( pred_bboxes = self.s2anet_bbox_post_process.get_pred(
pred_bboxes, bbox_num, im_shape, scale_factor) pred_bboxes, bbox_num, im_shape, scale_factor)
# output # output
output = {'bbox': pred_bboxes, 'bbox_num': bbox_num} output = {'bbox': pred_bboxes, 'bbox_num': bbox_num}
return output return output
......
...@@ -334,8 +334,11 @@ class RCNNBox(object): ...@@ -334,8 +334,11 @@ class RCNNBox(object):
self.num_classes = num_classes self.num_classes = num_classes
def __call__(self, bbox_head_out, rois, im_shape, scale_factor): def __call__(self, bbox_head_out, rois, im_shape, scale_factor):
bbox_pred, cls_prob = bbox_head_out bbox_pred = bbox_head_out[0]
roi, rois_num = rois cls_prob = bbox_head_out[1]
roi = rois[0]
rois_num = rois[1]
origin_shape = paddle.floor(im_shape / scale_factor + 0.5) origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
scale_list = [] scale_list = []
origin_shape_list = [] origin_shape_list = []
......
...@@ -264,7 +264,6 @@ class S2ANetBBoxPostProcess(nn.Layer): ...@@ -264,7 +264,6 @@ class S2ANetBBoxPostProcess(nn.Layer):
bbox_num = self.fake_bbox_num bbox_num = self.fake_bbox_num
pred_cls_score_bbox = paddle.reshape(pred_cls_score_bbox, [-1, 10]) pred_cls_score_bbox = paddle.reshape(pred_cls_score_bbox, [-1, 10])
assert pred_cls_score_bbox.shape[1] == 10
return pred_cls_score_bbox, bbox_num return pred_cls_score_bbox, bbox_num
def get_pred(self, bboxes, bbox_num, im_shape, scale_factor): def get_pred(self, bboxes, bbox_num, im_shape, scale_factor):
...@@ -281,7 +280,6 @@ class S2ANetBBoxPostProcess(nn.Layer): ...@@ -281,7 +280,6 @@ class S2ANetBBoxPostProcess(nn.Layer):
including labels, scores and bboxes. The size of including labels, scores and bboxes. The size of
bboxes are corresponding to the original image. bboxes are corresponding to the original image.
""" """
assert bboxes.shape[1] == 10
origin_shape = paddle.floor(im_shape / scale_factor + 0.5) origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
origin_shape_list = [] origin_shape_list = []
...@@ -307,6 +305,7 @@ class S2ANetBBoxPostProcess(nn.Layer): ...@@ -307,6 +305,7 @@ class S2ANetBBoxPostProcess(nn.Layer):
pred_bbox = bboxes[:, 2:] pred_bbox = bboxes[:, 2:]
# rescale bbox to original image # rescale bbox to original image
pred_bbox = pred_bbox.reshape([-1, 8])
scaled_bbox = pred_bbox / scale_factor_list scaled_bbox = pred_bbox / scale_factor_list
origin_h = origin_shape_list[:, 0] origin_h = origin_shape_list[:, 0]
origin_w = origin_shape_list[:, 1] origin_w = origin_shape_list[:, 1]
......
...@@ -62,11 +62,11 @@ class RPNHead(nn.Layer): ...@@ -62,11 +62,11 @@ class RPNHead(nn.Layer):
Args: Args:
anchor_generator (dict): configure of anchor generation anchor_generator (dict): configure of anchor generation
rpn_target_assign (dict): configure of rpn targets assignment rpn_target_assign (dict): configure of rpn targets assignment
train_proposal (dict): configure of proposals generation train_proposal (dict): configure of proposals generation
at the stage of training at the stage of training
test_proposal (dict): configure of proposals generation test_proposal (dict): configure of proposals generation
at the stage of prediction at the stage of prediction
in_channel (int): channel of input feature maps which can be in_channel (int): channel of input feature maps which can be
derived by from_config derived by from_config
""" """
...@@ -156,31 +156,35 @@ class RPNHead(nn.Layer): ...@@ -156,31 +156,35 @@ class RPNHead(nn.Layer):
""" """
prop_gen = self.train_proposal if self.training else self.test_proposal prop_gen = self.train_proposal if self.training else self.test_proposal
im_shape = inputs['im_shape'] im_shape = inputs['im_shape']
rpn_rois_list = [[] for i in range(batch_size)]
rpn_prob_list = [[] for i in range(batch_size)] # Collect multi-level proposals for each batch
rpn_rois_num_list = [[] for i in range(batch_size)] # Get 'topk' of them as final output
bs_rois_collect = []
bs_rois_num_collect = []
# Generate proposals for each level and each batch. # Generate proposals for each level and each batch.
# Discard batch-computing to avoid sorting bbox cross different batches. # Discard batch-computing to avoid sorting bbox cross different batches.
for rpn_score, rpn_delta, anchor in zip(scores, bbox_deltas, anchors): for i in range(batch_size):
for i in range(batch_size): rpn_rois_list = []
rpn_prob_list = []
rpn_rois_num_list = []
for rpn_score, rpn_delta, anchor in zip(scores, bbox_deltas,
anchors):
rpn_rois, rpn_rois_prob, rpn_rois_num, post_nms_top_n = prop_gen( rpn_rois, rpn_rois_prob, rpn_rois_num, post_nms_top_n = prop_gen(
scores=rpn_score[i:i + 1], scores=rpn_score[i:i + 1],
bbox_deltas=rpn_delta[i:i + 1], bbox_deltas=rpn_delta[i:i + 1],
anchors=anchor, anchors=anchor,
im_shape=im_shape[i:i + 1]) im_shape=im_shape[i:i + 1])
if rpn_rois.shape[0] > 0: if rpn_rois.shape[0] > 0:
rpn_rois_list[i].append(rpn_rois) rpn_rois_list.append(rpn_rois)
rpn_prob_list[i].append(rpn_rois_prob) rpn_prob_list.append(rpn_rois_prob)
rpn_rois_num_list[i].append(rpn_rois_num) rpn_rois_num_list.append(rpn_rois_num)
# Collect multi-level proposals for each batch
# Get 'topk' of them as final output
rois_collect = []
rois_num_collect = []
for i in range(batch_size):
if len(scores) > 1: if len(scores) > 1:
rpn_rois = paddle.concat(rpn_rois_list[i]) rpn_rois = paddle.concat(rpn_rois_list)
rpn_prob = paddle.concat(rpn_prob_list[i]).flatten() rpn_prob = paddle.concat(rpn_prob_list).flatten()
if rpn_prob.shape[0] > post_nms_top_n: if rpn_prob.shape[0] > post_nms_top_n:
topk_prob, topk_inds = paddle.topk(rpn_prob, post_nms_top_n) topk_prob, topk_inds = paddle.topk(rpn_prob, post_nms_top_n)
topk_rois = paddle.gather(rpn_rois, topk_inds) topk_rois = paddle.gather(rpn_rois, topk_inds)
...@@ -188,17 +192,19 @@ class RPNHead(nn.Layer): ...@@ -188,17 +192,19 @@ class RPNHead(nn.Layer):
topk_rois = rpn_rois topk_rois = rpn_rois
topk_prob = rpn_prob topk_prob = rpn_prob
else: else:
topk_rois = rpn_rois_list[i][0] topk_rois = rpn_rois_list[0]
topk_prob = rpn_prob_list[i][0].flatten() topk_prob = rpn_prob_list[0].flatten()
rois_collect.append(topk_rois)
rois_num_collect.append(paddle.shape(topk_rois)[0]) bs_rois_collect.append(topk_rois)
rois_num_collect = paddle.concat(rois_num_collect) bs_rois_num_collect.append(paddle.shape(topk_rois)[0])
bs_rois_num_collect = paddle.concat(bs_rois_num_collect)
return rois_collect, rois_num_collect return bs_rois_collect, bs_rois_num_collect
def get_loss(self, pred_scores, pred_deltas, anchors, inputs): def get_loss(self, pred_scores, pred_deltas, anchors, inputs):
""" """
pred_scores (list[Tensor]): Multi-level scores prediction pred_scores (list[Tensor]): Multi-level scores prediction
pred_deltas (list[Tensor]): Multi-level deltas prediction pred_deltas (list[Tensor]): Multi-level deltas prediction
anchors (list[Tensor]): Multi-level anchors anchors (list[Tensor]): Multi-level anchors
inputs (dict): ground truth info, including im, gt_bbox, gt_score inputs (dict): ground truth info, including im, gt_bbox, gt_score
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册