diff --git a/fluid/PaddleCV/yolov3/box_utils.py b/fluid/PaddleCV/yolov3/box_utils.py index 404da6dcdecc9e9edfef65fb742157d48000e309..ddd97071f6235cdd81eca8670c0ddad3f0e438fd 100644 --- a/fluid/PaddleCV/yolov3/box_utils.py +++ b/fluid/PaddleCV/yolov3/box_utils.py @@ -26,10 +26,6 @@ from matplotlib import pyplot as plt from PIL import Image -def sigmoid(x): - """Perform sigmoid to input numpy array""" - return 1.0 / (1.0 + np.exp(-1.0 * x)) - def coco_anno_box_to_center_relative(box, img_height, img_width): """ Convert COCO annotations box with format [x1, y1, w, h] to @@ -93,7 +89,7 @@ def box_iou_xywh(box1, box2): inter_area = inter_w * inter_h b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1) b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1) - + return inter_area / (b1_area + b2_area - inter_area) def box_iou_xyxy(box1, box2): @@ -115,32 +111,8 @@ def box_iou_xyxy(box1, box2): inter_area = inter_w * inter_h b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1) b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) - - return inter_area / (b1_area + b2_area - inter_area) -def rescale_box_in_input_image(boxes, im_shape, input_size): - """Scale (x1, x2, y1, y2) box of yolo output to input image""" - h, w = im_shape - # max_dim = max(h , w) - # boxes = boxes * max_dim / input_size - # dim_diff = np.abs(h - w) - # pad = dim_diff // 2 - # if h <= w: - # boxes[:, 1] -= pad - # boxes[:, 3] -= pad - # else: - # boxes[:, 0] -= pad - # boxes[:, 2] -= pad - fx = w / input_size - fy = h / input_size - boxes[:, 0] *= fx - boxes[:, 1] *= fy - boxes[:, 2] *= fx - boxes[:, 3] *= fy - boxes[boxes<0] = 0 - boxes[:, 2][boxes[:, 2] > (w - 1)] = w - 1 - boxes[:, 3][boxes[:, 3] > (h - 1)] = h - 1 - return boxes + return inter_area / (b1_area + b2_area - inter_area) def box_crop(boxes, labels, scores, crop, img_shape): x, y, w, h = map(float, crop) @@ -169,161 +141,6 @@ def box_crop(boxes, labels, scores, crop, img_shape): return boxes, labels, scores, mask.sum() -def get_yolo_detection(preds, anchors, class_num, img_width, img_height): - """Get yolo box, confidence score, class label from Darknet53 output""" - preds_n = np.array(preds) - n, c, h, w = preds_n.shape - anchor_num = len(anchors) // 2 - preds_n = preds_n.reshape([n, anchor_num, class_num + 5, h, w]) \ - .transpose((0, 1, 3, 4, 2)) - preds_n[:, :, :, :, :2] = sigmoid(preds_n[:, :, :, :, :2]) - preds_n[:, :, :, :, 4:] = sigmoid(preds_n[:, :, :, :, 4:]) - - pred_boxes = preds_n[:, :, :, :, :4] - pred_confs = preds_n[:, :, :, :, 4] - pred_scores = preds_n[:, :, :, :, 5:] * np.expand_dims(pred_confs, axis=4) - - grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1)) - grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w)) - anchors = [(anchors[i], anchors[i+1]) for i in range(0, len(anchors), 2)] - anchors_s = np.array([(an_w, an_h) for an_w, an_h in anchors]) - anchor_w = anchors_s[:, 0:1].reshape((1, anchor_num, 1, 1)) - anchor_h = anchors_s[:, 1:2].reshape((1, anchor_num, 1, 1)) - - pred_boxes[:, :, :, :, 0] += grid_x - pred_boxes[:, :, :, :, 1] += grid_y - pred_boxes[:, :, :, :, 2] = np.exp(pred_boxes[:, :, :, :, 2]) * anchor_w - pred_boxes[:, :, :, :, 3] = np.exp(pred_boxes[:, :, :, :, 3]) * anchor_h - - pred_boxes[:, :, :, :, 0] = pred_boxes[:, :, :, :, 0] * img_width / w - pred_boxes[:, :, :, :, 1] = pred_boxes[:, :, :, :, 1] * img_height / h - pred_boxes[:, :, :, :, 2] = pred_boxes[:, :, :, :, 2] - pred_boxes[:, :, :, :, 3] = pred_boxes[:, :, :, :, 3] - - pred_boxes = box_xywh_to_xyxy(pred_boxes) - pred_boxes = np.tile(np.expand_dims(pred_boxes, axis=4), (1, 1, 1, 1, class_num, 1)) - pred_labels = np.zeros_like(pred_scores) + np.arange(class_num) - - return ( - pred_boxes.reshape((n, -1, 4)), - pred_scores.reshape((n, -1)), - pred_labels.reshape((n, -1)), - ) - -def get_all_yolo_pred(outputs, yolo_anchors, yolo_classes, input_shape): - all_pred_boxes = [] - all_pred_scores = [] - all_pred_labels = [] - for output, anchors, classes in zip(outputs, yolo_anchors, yolo_classes): - pred_boxes, pred_scores, pred_labels = get_yolo_detection(output, anchors, classes, input_shape[0], input_shape[1]) - all_pred_boxes.append(pred_boxes) - all_pred_labels.append(pred_labels) - all_pred_scores.append(pred_scores) - pred_boxes = np.concatenate(all_pred_boxes, axis=1) - pred_scores = np.concatenate(all_pred_scores, axis=1) - pred_labels = np.concatenate(all_pred_labels, axis=1) - - return (pred_boxes, pred_scores, pred_labels) - -def calc_nms_box_new(pred_boxes, pred_scores, pred_labels, valid_thresh=0.01, nms_thresh=0.4, nms_topk=400, nms_posk=100): - output_boxes = np.empty((0, 4)) - output_scores = np.empty(0) - output_labels = np.empty(0) - for boxes, labels, scores in zip(pred_boxes, pred_labels, pred_scores): - valid_mask = scores > valid_thresh - boxes = boxes[valid_mask] - scores = scores[valid_mask] - labels = labels[valid_mask] - - score_sort_index = np.argsort(scores)[::-1] - boxes = boxes[score_sort_index][:nms_topk] - scores = scores[score_sort_index][:nms_topk] - labels = labels[score_sort_index][:nms_topk] - - for c in np.unique(labels): - c_mask = labels == c - c_boxes = boxes[c_mask] - c_scores = scores[c_mask] - - detect_boxes = [] - detect_scores = [] - detect_labels = [] - while c_boxes.shape[0]: - detect_boxes.append(c_boxes[0]) - detect_scores.append(c_scores[0]) - detect_labels.append(c) - if c_boxes.shape[0] == 1: - break - iou = box_iou_xyxy(detect_boxes[-1].reshape((1, 4)), c_boxes[1:]) - c_boxes = c_boxes[1:][iou < nms_thresh] - c_scores = c_scores[1:][iou < nms_thresh] - - output_boxes = np.append(output_boxes, detect_boxes, axis=0) - output_scores = np.append(output_scores, detect_scores) - output_labels = np.append(output_labels, detect_labels) - return (output_boxes, output_scores, output_labels) - - -def calc_nms_box(pred_boxes, pred_confs, pred_labels, im_shape, input_size, valid_thresh=0.8, nms_thresh=0.4, nms_topk=400, nms_posk=100): - """ - Removes detections which confidence score under valid_thresh and perform - Non-Maximun Suppression to filtered boxes - """ - _, box_num, class_num = pred_labels.shape - pred_boxes = box_xywh_to_xyxy(pred_boxes) - output_boxes = np.empty((0, 4)) - output_scores = np.empty(0) - output_labels = np.empty((0)) - for i, (boxes, confs, classes) in enumerate(zip(pred_boxes, pred_confs, pred_labels)): - conf_mask = confs > valid_thresh - if conf_mask.sum() == 0: - continue - boxes = boxes[conf_mask] - classes = classes[conf_mask] - confs = confs[conf_mask] - - conf_sort_index = np.argsort(confs)[::-1] - boxes = boxes[conf_sort_index][:nms_topk] - classes = classes[conf_sort_index][:nms_topk] - confs = confs[conf_sort_index][:nms_topk] - cls_score = np.max(classes, axis=1) - cls_pred = np.argmax(classes, axis=1) - - for c in np.unique(cls_pred): - c_mask = cls_pred == c - c_confs = confs[c_mask] - c_boxes = boxes[c_mask] - c_scores = cls_score[c_mask] - c_score_index = np.argsort(c_scores) - c_boxes_s = c_boxes[c_score_index[::-1]] - c_confs_s = c_confs[c_score_index[::-1]] - c_scores_s = c_scores[c_score_index[::-1]] - - detect_boxes = [] - detect_scores = [] - detect_labels = [] - while c_boxes_s.shape[0]: - detect_boxes.append(c_boxes_s[0]) - detect_scores.append(c_scores_s[0]) - detect_labels.append(c) - if c_boxes_s.shape[0] == 1: - break - iou = box_iou_xyxy(detect_boxes[-1].reshape((1, 4)), c_boxes_s[1:]) - c_boxes_s = c_boxes_s[1:][iou < nms_thresh] - c_confs_s = c_confs_s[1:][iou < nms_thresh] - c_scores_s = c_scores_s[1:][iou < nms_thresh] - - output_boxes = np.append(output_boxes, detect_boxes, axis=0) - output_scores = np.append(output_scores, detect_scores) - output_labels = np.append(output_labels, detect_labels) - - output_boxes = output_boxes[:nms_posk] - output_scores = output_scores[:nms_posk] - output_labels = output_labels[:nms_posk] - - output_boxes = rescale_box_in_input_image(output_boxes, im_shape, input_size) - return (output_boxes, output_scores, output_labels) - def draw_boxes_on_image(image_path, boxes, scores, labels, label_names, score_thresh=0.5): image = np.array(Image.open(image_path)) plt.figure() diff --git a/fluid/PaddleCV/yolov3/eval.py b/fluid/PaddleCV/yolov3/eval.py index 004ddf23c2999329a1b27b3f24ab56f16cf2c297..ae5d29b089b36fcc7034fd32ccc40af8cef8b48d 100644 --- a/fluid/PaddleCV/yolov3/eval.py +++ b/fluid/PaddleCV/yolov3/eval.py @@ -20,7 +20,6 @@ import time import numpy as np import paddle import paddle.fluid as fluid -import box_utils import reader import models from utility import print_arguments, parse_args @@ -64,6 +63,8 @@ def eval(): def get_pred_result(boxes, scores, labels, im_id): result = [] for box, score, label in zip(boxes, scores, labels): + if score < 0.05: + continue x1, y1, x2, y2 = box w = x2 - x1 + 1 h = y2 - y1 + 1 @@ -72,41 +73,41 @@ def eval(): res = { 'image_id': im_id, 'category_id': label_ids[int(label)], - 'bbox': bbox, - 'score': score + 'bbox': map(float, bbox), + 'score': float(score) } result.append(res) return result dts_res = [] - fetch_list = outputs + fetch_list = [outputs] total_time = 0 for batch_id, batch_data in enumerate(test_reader()): start_time = time.time() batch_outputs = exe.run( fetch_list=[v.name for v in fetch_list], feed=feeder.feed(batch_data), - return_numpy=False) - for data, outputs in zip(batch_data, batch_outputs): - im_id = data[1] - im_shape = data[2] - pred_boxes, pred_scores, pred_labels = box_utils.get_all_yolo_pred( - batch_outputs, yolo_anchors, yolo_classes, (input_size, input_size)) - boxes, scores, labels = box_utils.calc_nms_box_new(pred_boxes, pred_scores, pred_labels, - cfg.valid_thresh, cfg.nms_thresh) - boxes = box_utils.rescale_box_in_input_image(boxes, im_shape, input_size) + return_numpy=False, + use_program_cache=True) + lod = batch_outputs[0].lod()[0] + nmsed_boxes = np.array(batch_outputs[0]) + if nmsed_boxes.shape[1] != 6: + continue + for i in range(len(lod) - 1): + im_id = batch_data[i][1] + start = lod[i] + end = lod[i + 1] + if start == end: + continue + nmsed_box = nmsed_boxes[start:end, :] + labels = nmsed_box[:, 0] + scores = nmsed_box[:, 1] + boxes = nmsed_box[:, 2:6] dts_res += get_pred_result(boxes, scores, labels, im_id) - end_time = time.time() - print("batch id: {}, time: {}".format(batch_id, end_time - start_time)) - total_time += (end_time - start_time) - if cfg.debug: - if '2014' in cfg.dataset: - img_name = "COCO_val2014_{:012d}.jpg".format(im_id) - box_utils.draw_boxes_on_image(os.path.join("./dataset/coco/val2014", img_name), boxes, scores, labels, label_names) - if '2017' in cfg.dataset: - img_name = "{:012d}.jpg".format(im_id) - box_utils.draw_boxes_on_image(os.path.join("./dataset/coco/val2017", img_name), boxes, scores, labels, label_names) + end_time = time.time() + print("batch id: {}, time: {}".format(batch_id, end_time - start_time)) + total_time += end_time - start_time with open("yolov3_result.json", 'w') as outfile: json.dump(dts_res, outfile) diff --git a/fluid/PaddleCV/yolov3/infer.py b/fluid/PaddleCV/yolov3/infer.py index b67a8deb5d0d299e82d224832635e725336f67b5..1a5aa3b3f64c4cdb11b7f03054cd56559a35b113 100644 --- a/fluid/PaddleCV/yolov3/infer.py +++ b/fluid/PaddleCV/yolov3/infer.py @@ -34,7 +34,8 @@ def infer(): fluid.io.load_vars(exe, cfg.pretrained_model, predicate=if_exist) # yapf: enable feeder = fluid.DataFeeder(place=place, feed_list=model.feeds()) - fetch_list = outputs + fetch_list = [outputs] + # fetch_list = outputs image_names = [] if cfg.image_name is not None: image_names.append(cfg.image_name) @@ -50,13 +51,14 @@ def infer(): outputs = exe.run( fetch_list=[v.name for v in fetch_list], feed=feeder.feed(data), - return_numpy=True) + return_numpy=False) + bboxes = np.array(outputs[0]) + if bboxes.shape[1] != 6: + print("No object found in {}".format(image_name)) + labels = bboxes[:, 0].astype('int32') + scores = bboxes[:, 1].astype('float32') + boxes = bboxes[:, 2:].astype('float32') - pred_boxes, pred_scores, pred_labels = box_utils.get_all_yolo_pred(outputs, yolo_anchors, - yolo_classes, (input_size, input_size)) - boxes, scores, labels = box_utils.calc_nms_box_new(pred_boxes, pred_scores, pred_labels, - cfg.valid_thresh, cfg.nms_thresh) - boxes = box_utils.rescale_box_in_input_image(boxes, im_shape, input_size) path = os.path.join(cfg.image_path, image_name) box_utils.draw_boxes_on_image(path, boxes, scores, labels, label_names, cfg.draw_thresh) diff --git a/fluid/PaddleCV/yolov3/models.py b/fluid/PaddleCV/yolov3/models.py index 63a5299105dcf44c7ff88d987bd78a012c1ff338..2dd5f093c192ad3010737dda1ddabc0a416fe2ae 100644 --- a/fluid/PaddleCV/yolov3/models.py +++ b/fluid/PaddleCV/yolov3/models.py @@ -99,6 +99,8 @@ class YOLOv3(object): self.use_random = use_random self.outputs = [] self.losses = [] + self.boxes = [] + self.scores = [] self.downsample = 32 def build_model(self): @@ -213,7 +215,19 @@ class YOLOv3(object): # use_label_smooth=False, name="yolo_loss"+str(i)) self.losses.append(fluid.layers.reduce_mean(loss)) - self.downsample //= 2 + else: + boxes, scores = fluid.layers.yolo_box( + x=out, + img_size=self.im_shape, + anchors=mask_anchors, + class_num=class_num, + conf_thresh=cfg.valid_thresh, + downsample_ratio=self.downsample, + name="yolo_box"+str(i)) + self.boxes.append(boxes) + self.scores.append(fluid.layers.transpose(scores, perm=[0, 2, 1])) + + self.downsample //= 2 layer_outputs.append(out) @@ -221,7 +235,17 @@ class YOLOv3(object): return sum(self.losses) def get_pred(self): - return self.outputs + yolo_boxes = fluid.layers.concat(self.boxes, axis=1) + yolo_scores = fluid.layers.concat(self.scores, axis=2) + return fluid.layers.multiclass_nms( + bboxes=yolo_boxes, + scores=yolo_scores, + score_threshold=cfg.valid_thresh, + nms_top_k=cfg.nms_topk, + keep_top_k=cfg.nms_posk, + nms_threshold=cfg.nms_thresh, + background_label=-1, + name="multiclass_nms") def get_yolo_anchors(self): return self.yolo_anchors diff --git a/fluid/PaddleCV/yolov3/reader.py b/fluid/PaddleCV/yolov3/reader.py index cb6ccdda3d37b246df89fa6580433bb346962fe9..63e3bec7276d64be5ef90e8a9a004f6140568b34 100644 --- a/fluid/PaddleCV/yolov3/reader.py +++ b/fluid/PaddleCV/yolov3/reader.py @@ -156,7 +156,7 @@ class DataSetReader(object): h, w, _ = im.shape im_scale_x = size / float(w) im_scale_y = size / float(h) - out_img = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=cv2.INTER_LINEAR) + out_img = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=cv2.INTER_CUBIC) mean = np.array(mean).reshape((1, 1, -1)) std = np.array(std).reshape((1, 1, -1)) out_img = (out_img / 255.0 - mean) / std