未验证 提交 97373901 编写于 作者: S shangliang Xu 提交者: GitHub

[depoly] fix inconsistency between cpp and python (#4351)

上级 35a93b11
......@@ -128,27 +128,36 @@ static void MkDirs(const std::string& path) {
}
void PredictVideo(const std::string& video_path,
PaddleDetection::ObjectDetector* det) {
PaddleDetection::ObjectDetector* det,
const std::string& output_dir = "output") {
// Open video
cv::VideoCapture capture;
std::string video_out_name = "output.mp4";
if (FLAGS_camera_id != -1){
capture.open(FLAGS_camera_id);
}else{
capture.open(video_path.c_str());
video_out_name = video_path.substr(video_path.find_last_of(OS_PATH_SEP) + 1);
}
if (!capture.isOpened()) {
printf("can not open video : %s\n", video_path.c_str());
return;
}
// Get Video info : resolution, fps
// Get Video info : resolution, fps, frame count
int video_width = static_cast<int>(capture.get(CV_CAP_PROP_FRAME_WIDTH));
int video_height = static_cast<int>(capture.get(CV_CAP_PROP_FRAME_HEIGHT));
int video_fps = static_cast<int>(capture.get(CV_CAP_PROP_FPS));
int video_frame_count = static_cast<int>(capture.get(CV_CAP_PROP_FRAME_COUNT));
printf("fps: %d, frame_count: %d\n", video_fps, video_frame_count);
// Create VideoWriter for output
cv::VideoWriter video_out;
std::string video_out_path = "output.mp4";
std::string video_out_path(output_dir);
if (output_dir.rfind(OS_PATH_SEP) != output_dir.size() - 1) {
video_out_path += OS_PATH_SEP;
}
video_out_path += video_out_name;
video_out.open(video_out_path.c_str(),
0x00000021,
video_fps,
......@@ -166,7 +175,7 @@ void PredictVideo(const std::string& video_path,
auto colormap = PaddleDetection::GenerateColorMap(labels.size());
// Capture all frames and do inference
cv::Mat frame;
int frame_id = 0;
int frame_id = 1;
bool is_rbox = false;
while (capture.read(frame)) {
if (frame.empty()) {
......@@ -174,8 +183,14 @@ void PredictVideo(const std::string& video_path,
}
std::vector<cv::Mat> imgs;
imgs.push_back(frame);
det->Predict(imgs, 0.5, 0, 1, &result, &bbox_num, &det_times);
printf("detect frame: %d\n", frame_id);
det->Predict(imgs, FLAGS_threshold, 0, 1, &result, &bbox_num, &det_times);
std::vector<PaddleDetection::ObjectResult> out_result;
for (const auto& item : result) {
if (item.confidence < FLAGS_threshold || item.class_id == -1) {
continue;
}
out_result.push_back(item);
if (item.rect.size() > 6){
is_rbox = true;
printf("class=%d confidence=%.4f rect=[%d %d %d %d %d %d %d %d]\n",
......@@ -202,7 +217,7 @@ void PredictVideo(const std::string& video_path,
}
cv::Mat out_im = PaddleDetection::VisualizeResult(
frame, result, labels, colormap, is_rbox);
frame, out_result, labels, colormap, is_rbox);
video_out.write(out_im);
frame_id += 1;
......@@ -337,12 +352,12 @@ int main(int argc, char** argv) {
FLAGS_trt_min_shape, FLAGS_trt_max_shape, FLAGS_trt_opt_shape,
FLAGS_trt_calib_mode);
// Do inference on input video or image
if (!PathExists(FLAGS_output_dir)) {
MkDirs(FLAGS_output_dir);
}
if (!FLAGS_video_file.empty() || FLAGS_camera_id != -1) {
PredictVideo(FLAGS_video_file, &det);
PredictVideo(FLAGS_video_file, &det, FLAGS_output_dir);
} else if (!FLAGS_image_file.empty() || !FLAGS_image_dir.empty()) {
if (!PathExists(FLAGS_output_dir)) {
MkDirs(FLAGS_output_dir);
}
std::vector<std::string> all_img_paths;
std::vector<cv::String> cv_all_img_paths;
if (!FLAGS_image_file.empty()) {
......
......@@ -128,27 +128,36 @@ static void MkDirs(const std::string& path) {
}
void PredictVideo(const std::string& video_path,
PaddleDetection::JDEDetector* mot) {
PaddleDetection::JDEDetector* mot,
const std::string& output_dir = "output") {
// Open video
cv::VideoCapture capture;
std::string video_out_name = "output.mp4";
if (FLAGS_camera_id != -1){
capture.open(FLAGS_camera_id);
}else{
capture.open(video_path.c_str());
video_out_name = video_path.substr(video_path.find_last_of(OS_PATH_SEP) + 1);
}
if (!capture.isOpened()) {
printf("can not open video : %s\n", video_path.c_str());
return;
}
// Get Video info : resolution, fps
// Get Video info : resolution, fps, frame count
int video_width = static_cast<int>(capture.get(CV_CAP_PROP_FRAME_WIDTH));
int video_height = static_cast<int>(capture.get(CV_CAP_PROP_FRAME_HEIGHT));
int video_fps = static_cast<int>(capture.get(CV_CAP_PROP_FPS));
int video_frame_count = static_cast<int>(capture.get(CV_CAP_PROP_FRAME_COUNT));
printf("fps: %d, frame_count: %d\n", video_fps, video_frame_count);
// Create VideoWriter for output
cv::VideoWriter video_out;
std::string video_out_path = "mot_output.mp4";
std::string video_out_path(output_dir);
if (output_dir.rfind(OS_PATH_SEP) != output_dir.size() - 1) {
video_out_path += OS_PATH_SEP;
}
video_out_path += video_out_name;
video_out.open(video_out_path.c_str(),
0x00000021,
video_fps,
......@@ -164,14 +173,15 @@ void PredictVideo(const std::string& video_path,
double times;
// Capture all frames and do inference
cv::Mat frame;
int frame_id = 0;
int frame_id = 1;
while (capture.read(frame)) {
if (frame.empty()) {
break;
}
std::vector<cv::Mat> imgs;
imgs.push_back(frame);
mot->Predict(imgs, 0.5, 0, 1, &result, &det_times);
printf("detect frame: %d\n", frame_id);
mot->Predict(imgs, FLAGS_threshold, 0, 1, &result, &det_times);
frame_id += 1;
times = std::accumulate(det_times.begin(), det_times.end(), 0) / frame_id;
......@@ -215,7 +225,9 @@ int main(int argc, char** argv) {
FLAGS_cpu_threads, FLAGS_run_mode, FLAGS_batch_size,FLAGS_gpu_id,
FLAGS_trt_min_shape, FLAGS_trt_max_shape, FLAGS_trt_opt_shape,
FLAGS_trt_calib_mode);
PredictVideo(FLAGS_video_file, &mot);
if (!PathExists(FLAGS_output_dir)) {
MkDirs(FLAGS_output_dir);
}
PredictVideo(FLAGS_video_file, &mot, FLAGS_output_dir);
return 0;
}
......@@ -138,27 +138,36 @@ static void MkDirs(const std::string& path) {
void PredictVideo(const std::string& video_path,
PaddleDetection::ObjectDetector* det,
PaddleDetection::KeyPointDetector* keypoint) {
PaddleDetection::KeyPointDetector* keypoint,
const std::string& output_dir = "output") {
// Open video
cv::VideoCapture capture;
std::string video_out_name = "output.mp4";
if (FLAGS_camera_id != -1){
capture.open(FLAGS_camera_id);
}else{
capture.open(video_path.c_str());
video_out_name = video_path.substr(video_path.find_last_of(OS_PATH_SEP) + 1);
}
if (!capture.isOpened()) {
printf("can not open video : %s\n", video_path.c_str());
return;
}
// Get Video info : resolution, fps
// Get Video info : resolution, fps, frame count
int video_width = static_cast<int>(capture.get(CV_CAP_PROP_FRAME_WIDTH));
int video_height = static_cast<int>(capture.get(CV_CAP_PROP_FRAME_HEIGHT));
int video_fps = static_cast<int>(capture.get(CV_CAP_PROP_FPS));
int video_frame_count = static_cast<int>(capture.get(CV_CAP_PROP_FRAME_COUNT));
printf("fps: %d, frame_count: %d\n", video_fps, video_frame_count);
// Create VideoWriter for output
cv::VideoWriter video_out;
std::string video_out_path = "output.mp4";
std::string video_out_path(output_dir);
if (output_dir.rfind(OS_PATH_SEP) != output_dir.size() - 1) {
video_out_path += OS_PATH_SEP;
}
video_out_path += video_out_name;
video_out.open(video_out_path.c_str(),
0x00000021,
video_fps,
......@@ -184,7 +193,7 @@ void PredictVideo(const std::string& video_path,
std::vector<int> colormap_kpts = PaddleDetection::GenerateColorMap(20);
// Capture all frames and do inference
cv::Mat frame;
int frame_id = 0;
int frame_id = 1;
bool is_rbox = false;
while (capture.read(frame)) {
if (frame.empty()) {
......@@ -192,8 +201,14 @@ void PredictVideo(const std::string& video_path,
}
std::vector<cv::Mat> imgs;
imgs.push_back(frame);
det->Predict(imgs, 0.5, 0, 1, &result, &bbox_num, &det_times);
printf("detect frame: %d\n", frame_id);
det->Predict(imgs, FLAGS_threshold, 0, 1, &result, &bbox_num, &det_times);
std::vector<PaddleDetection::ObjectResult> out_result;
for (const auto& item : result) {
if (item.confidence < FLAGS_threshold || item.class_id == -1) {
continue;
}
out_result.push_back(item);
if (item.rect.size() > 6){
is_rbox = true;
printf("class=%d confidence=%.4f rect=[%d %d %d %d %d %d %d %d]\n",
......@@ -221,9 +236,9 @@ void PredictVideo(const std::string& video_path,
if(keypoint)
{
int imsize = result.size();
int imsize = out_result.size();
for (int i=0; i<imsize; i++){
auto item = result[i];
auto item = out_result[i];
cv::Mat crop_img;
std::vector<double> keypoint_times;
std::vector<int> rect = {item.rect[0], item.rect[1], item.rect[2], item.rect[3]};
......@@ -239,7 +254,7 @@ void PredictVideo(const std::string& video_path,
if (imgs_kpts.size()==FLAGS_batch_size_keypoint || ((i==imsize-1)&&!imgs_kpts.empty()))
{
keypoint->Predict(imgs_kpts, center_bs, scale_bs, 0.5, 0, 1, &result_kpts, &keypoint_times);
keypoint->Predict(imgs_kpts, center_bs, scale_bs, FLAGS_threshold, 0, 1, &result_kpts, &keypoint_times);
imgs_kpts.clear();
center_bs.clear();
scale_bs.clear();
......@@ -251,7 +266,7 @@ void PredictVideo(const std::string& video_path,
else{
// Visualization result
cv::Mat out_im = PaddleDetection::VisualizeResult(
frame, result, labels, colormap, is_rbox);
frame, out_result, labels, colormap, is_rbox);
video_out.write(out_im);
}
......@@ -450,12 +465,12 @@ int main(int argc, char** argv) {
FLAGS_trt_calib_mode, FLAGS_use_dark);
}
// Do inference on input video or image
if (!PathExists(FLAGS_output_dir)) {
MkDirs(FLAGS_output_dir);
}
if (!FLAGS_video_file.empty() || FLAGS_camera_id != -1) {
PredictVideo(FLAGS_video_file, &det, keypoint);
PredictVideo(FLAGS_video_file, &det, keypoint, FLAGS_output_dir);
} else if (!FLAGS_image_file.empty() || !FLAGS_image_dir.empty()) {
if (!PathExists(FLAGS_output_dir)) {
MkDirs(FLAGS_output_dir);
}
std::vector<std::string> all_img_paths;
std::vector<cv::String> cv_all_img_paths;
if (!FLAGS_image_file.empty()) {
......
......@@ -133,22 +133,24 @@ def topdown_unite_predict_video(detector,
topdown_keypoint_detector,
camera_id,
keypoint_batch_size=1):
video_name = 'output.mp4'
if camera_id != -1:
capture = cv2.VideoCapture(camera_id)
video_name = 'output.mp4'
else:
capture = cv2.VideoCapture(FLAGS.video_file)
video_name = os.path.splitext(os.path.basename(FLAGS.video_file))[
0] + '.mp4'
fps = 30
# Get Video info : resolution, fps, frame count
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
# yapf: disable
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
# yapf: enable
fps = int(capture.get(cv2.CAP_PROP_FPS))
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
print("fps: %d, frame_count: %d" % (fps, frame_count))
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name)
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
index = 0
while (1):
......@@ -156,7 +158,7 @@ def topdown_unite_predict_video(detector,
if not ret:
break
index += 1
print('detect frame:%d' % (index))
print('detect frame: %d' % (index))
frame2 = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
results = detector.predict([frame2], FLAGS.det_threshold)
......
......@@ -664,30 +664,30 @@ def predict_image(detector, image_list, batch_size=1):
def predict_video(detector, camera_id):
video_out_name = 'output.mp4'
if camera_id != -1:
capture = cv2.VideoCapture(camera_id)
video_name = 'output.mp4'
else:
capture = cv2.VideoCapture(FLAGS.video_file)
video_name = os.path.split(FLAGS.video_file)[-1]
fps = 30
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
print('frame_count', frame_count)
video_out_name = os.path.split(FLAGS.video_file)[-1]
# Get Video info : resolution, fps, frame count
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
# yapf: disable
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
# yapf: enable
fps = int(capture.get(cv2.CAP_PROP_FPS))
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
print("fps: %d, frame_count: %d" % (fps, frame_count))
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name)
out_path = os.path.join(FLAGS.output_dir, video_out_name)
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
index = 1
while (1):
ret, frame = capture.read()
if not ret:
break
print('detect frame:%d' % (index))
print('detect frame: %d' % (index))
index += 1
results = detector.predict([frame], FLAGS.threshold)
im = visualize_box_mask(
......
......@@ -284,28 +284,30 @@ def predict_image(detector, image_list):
def predict_video(detector, camera_id):
video_name = 'output.mp4'
if camera_id != -1:
capture = cv2.VideoCapture(camera_id)
video_name = 'output.mp4'
else:
capture = cv2.VideoCapture(FLAGS.video_file)
video_name = os.path.split(FLAGS.video_file)[-1]
fps = 30
# Get Video info : resolution, fps, frame count
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
# yapf: disable
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
# yapf: enable
fps = int(capture.get(cv2.CAP_PROP_FPS))
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
print("fps: %d, frame_count: %d" % (fps, frame_count))
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name + '.mp4')
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
index = 1
while (1):
ret, frame = capture.read()
if not ret:
break
print('detect frame:%d' % (index))
print('detect frame: %d' % (index))
index += 1
results = detector.predict([frame], FLAGS.threshold)
im = draw_pose(
......
......@@ -212,24 +212,24 @@ def predict_image(detector, image_list):
def predict_video(detector, camera_id):
video_name = 'mot_output.mp4'
if camera_id != -1:
capture = cv2.VideoCapture(camera_id)
video_name = 'mot_output.mp4'
else:
capture = cv2.VideoCapture(FLAGS.video_file)
video_name = os.path.split(FLAGS.video_file)[-1]
fps = 30
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
print('frame_count', frame_count)
# Get Video info : resolution, fps, frame count
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
# yapf: disable
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
# yapf: enable
fps = int(capture.get(cv2.CAP_PROP_FPS))
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
print("fps: %d, frame_count: %d" % (fps, frame_count))
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name)
if not FLAGS.save_images:
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0
timer = MOTTimer()
......@@ -270,7 +270,7 @@ def predict_video(detector, camera_id):
write_mot_results(result_filename, [results[-1]])
frame_id += 1
print('detect frame:%d' % (frame_id))
print('detect frame: %d' % (frame_id))
if camera_id != -1:
cv2.imshow('Tracking Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
......
......@@ -126,18 +126,18 @@ def mot_keypoint_unite_predict_video(mot_model,
else:
capture = cv2.VideoCapture(FLAGS.video_file)
video_name = os.path.split(FLAGS.video_file)[-1]
fps = 30
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
print('frame_count', frame_count)
# Get Video info : resolution, fps, frame count
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
# yapf: disable
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
# yapf: enable
fps = int(capture.get(cv2.CAP_PROP_FPS))
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
print("fps: %d, frame_count: %d" % (fps, frame_count))
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name)
if not FLAGS.save_images:
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0
timer_mot = FPSTimer()
......@@ -195,7 +195,7 @@ def mot_keypoint_unite_predict_video(mot_model,
im = np.array(online_im)
frame_id += 1
print('detect frame:%d' % (frame_id))
print('detect frame: %d' % (frame_id))
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
......
......@@ -355,18 +355,18 @@ def predict_video(detector, reid_model, camera_id):
else:
capture = cv2.VideoCapture(FLAGS.video_file)
video_name = os.path.split(FLAGS.video_file)[-1]
fps = 30
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
print('frame_count', frame_count)
# Get Video info : resolution, fps, frame count
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
# yapf: disable
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
# yapf: enable
fps = int(capture.get(cv2.CAP_PROP_FPS))
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
print("fps: %d, frame_count: %d" % (fps, frame_count))
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name)
if not FLAGS.save_images:
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0
timer = MOTTimer()
......@@ -425,7 +425,7 @@ def predict_video(detector, reid_model, camera_id):
write_mot_results(result_filename, [result])
frame_id += 1
print('detect frame:%d' % (frame_id))
print('detect frame: %d' % (frame_id))
if camera_id != -1:
cv2.imshow('Tracking Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册