From 7a614e769d779c1000df763ce28d0e7460fbfd66 Mon Sep 17 00:00:00 2001
From: sunyanfang01
Date: Fri, 5 Jun 2020 18:59:43 +0800
Subject: [PATCH] Fix BlazeFace net inputs, NMS parameters, and gt_bbox
 normalization

---
 paddlex/cv/datasets/widerface.py          |  1 +
 paddlex/cv/models/blazeface.py            |  6 ++-
 paddlex/cv/models/utils/detection_eval.py | 10 +++++
 paddlex/cv/nets/__init__.py               |  1 +
 paddlex/cv/nets/blazenet.py               | 17 +++----
 paddlex/cv/nets/detection/blazeface.py    | 36 ++++++---------
 paddlex/cv/transforms/box_utils.py        | 54 ++++++++++++++++++++++-
 paddlex/cv/transforms/det_transforms.py   | 27 ++++++++----
 paddlex/cv/transforms/ops.py              |  5 ++-
 9 files changed, 111 insertions(+), 46 deletions(-)

diff --git a/paddlex/cv/datasets/widerface.py b/paddlex/cv/datasets/widerface.py
index 45ea796..5e96234 100644
--- a/paddlex/cv/datasets/widerface.py
+++ b/paddlex/cv/datasets/widerface.py
@@ -86,6 +86,7 @@ class WIDERFACEDetection(VOCDetection):
                     continue
                 else:
                     is_discard = True
+            print(img_file)
             im = cv2.imread(img_file)
             im_w = im.shape[1]
             im_h = im.shape[0]
diff --git a/paddlex/cv/models/blazeface.py b/paddlex/cv/models/blazeface.py
index 2ae1f3a..78eb2a0 100644
--- a/paddlex/cv/models/blazeface.py
+++ b/paddlex/cv/models/blazeface.py
@@ -22,6 +22,7 @@ import paddlex.utils.logging as logging
 import paddlex
 from .base import BaseAPI
 from collections import OrderedDict
+from .utils.detection_eval import eval_results, bbox2out
 import copy
 
 class BlazeFace(BaseAPI):
@@ -74,13 +75,14 @@ class BlazeFace(BaseAPI):
     def build_net(self, mode='train'):
         model = paddlex.cv.nets.detection.BlazeFace(
             backbone=self._get_backbone(self.backbone),
+            mode=mode,
             min_sizes=self.min_sizes,
             num_classes=self.num_classes,
             use_density_prior_box=self.use_density_prior_box,
             densities=self.densities,
             nms_threshold=self.nms_iou_threshold,
             nms_topk=self.nms_topk,
-            nms_keep_topk=self.nms_score_threshold,
+            nms_keep_topk=self.nms_keep_topk,
             score_threshold=self.nms_score_threshold,
             fixed_input_shape=self.fixed_input_shape)
         inputs = model.generate_inputs()
@@ -263,6 +265,8 @@ class BlazeFace(BaseAPI):
             }
             res_im_id = [d[4] for d in data]
             res['im_id'] = (np.array(res_im_id), [])
+            res_im_shape = [d[5] for d in data]
+            res['im_shape'] = (np.array(res_im_shape), [])
             if metric == 'VOC':
                 res_gt_box = []
                 res_gt_label = []
diff --git a/paddlex/cv/models/utils/detection_eval.py b/paddlex/cv/models/utils/detection_eval.py
index b9dcdaa..a068aef 100644
--- a/paddlex/cv/models/utils/detection_eval.py
+++ b/paddlex/cv/models/utils/detection_eval.py
@@ -84,6 +84,16 @@ def eval_results(results,
     return box_ap_stats, eval_details
 
 
+def clip_bbox(bbox, im_size=None):
+    h = 1. if im_size is None else im_size[0]
+    w = 1. if im_size is None else im_size[1]
+    xmin = max(min(bbox[0], w), 0.)
+    ymin = max(min(bbox[1], h), 0.)
+    xmax = max(min(bbox[2], w), 0.)
+    ymax = max(min(bbox[3], h), 0.)
+    return xmin, ymin, xmax, ymax
+
+
 def proposal_eval(results, coco_gt, outputfile, max_dets=(100, 300, 1000)):
     assert 'proposal' in results[0]
     assert outfile.endswith('.json')
diff --git a/paddlex/cv/nets/__init__.py b/paddlex/cv/nets/__init__.py
index b1441c5..7d3c1a4 100644
--- a/paddlex/cv/nets/__init__.py
+++ b/paddlex/cv/nets/__init__.py
@@ -24,6 +24,7 @@ from .xception import Xception
 from .densenet import DenseNet
 from .shufflenet_v2 import ShuffleNetV2
 from .hrnet import HRNet
+from .blazenet import BlazeNet
 
 
 def resnet18(input, num_classes=1000):
diff --git a/paddlex/cv/nets/blazenet.py b/paddlex/cv/nets/blazenet.py
index fac5a4e..121fb2a 100644
--- a/paddlex/cv/nets/blazenet.py
+++ b/paddlex/cv/nets/blazenet.py
@@ -19,10 +19,6 @@ from __future__ import print_function
 from paddle import fluid
 from paddle.fluid.param_attr import ParamAttr
 
-from ppdet.experimental import mixed_precision_global_state
-from ppdet.core.workspace import register
-
-
 
 class BlazeNet(object):
     """
@@ -147,7 +143,6 @@ class BlazeNet(object):
         use_pool = not stride == 1
         use_double_block = double_channels is not None
         act = 'relu' if use_double_block else None
-        mixed_precision_enabled = mixed_precision_global_state() is not None
 
         if use_5x5kernel:
             conv_dw = self._conv_norm(
@@ -157,7 +152,7 @@ class BlazeNet(object):
                 stride=stride,
                 padding=2,
                 num_groups=in_channels,
-                use_cudnn=mixed_precision_enabled,
+                use_cudnn=True,
                 name=name + "1_dw")
         else:
             conv_dw_1 = self._conv_norm(
@@ -167,7 +162,7 @@ class BlazeNet(object):
                 stride=1,
                 padding=1,
                 num_groups=in_channels,
-                use_cudnn=mixed_precision_enabled,
+                use_cudnn=True,
                 name=name + "1_dw_1")
             conv_dw = self._conv_norm(
                 input=conv_dw_1,
@@ -176,7 +171,7 @@ class BlazeNet(object):
                 stride=stride,
                 padding=1,
                 num_groups=in_channels,
-                use_cudnn=mixed_precision_enabled,
+                use_cudnn=True,
                 name=name + "1_dw_2")
 
         conv_pw = self._conv_norm(
@@ -196,7 +191,7 @@ class BlazeNet(object):
                     num_filters=out_channels,
                     stride=1,
                     padding=2,
-                    use_cudnn=mixed_precision_enabled,
+                    use_cudnn=True,
                     name=name + "2_dw")
             else:
                 conv_dw_1 = self._conv_norm(
@@ -206,7 +201,7 @@ class BlazeNet(object):
                     stride=1,
                     padding=1,
                     num_groups=out_channels,
-                    use_cudnn=mixed_precision_enabled,
+                    use_cudnn=True,
                     name=name + "2_dw_1")
                 conv_dw = self._conv_norm(
                     input=conv_dw_1,
@@ -215,7 +210,7 @@ class BlazeNet(object):
                     stride=1,
                     padding=1,
                     num_groups=out_channels,
-                    use_cudnn=mixed_precision_enabled,
+                    use_cudnn=True,
                     name=name + "2_dw_2")
 
             conv_pw = self._conv_norm(
diff --git a/paddlex/cv/nets/detection/blazeface.py b/paddlex/cv/nets/detection/blazeface.py
index d838861..069fcb0 100644
--- a/paddlex/cv/nets/detection/blazeface.py
+++ b/paddlex/cv/nets/detection/blazeface.py
@@ -20,6 +20,7 @@ from collections import OrderedDict
 class BlazeFace:
     def __init__(self,
                  backbone,
+                 mode='train',
                  min_sizes=[[16., 24.], [32., 48., 64., 80., 96., 128.]],
                  max_sizes=None,
                  steps=[8., 16.],
@@ -33,8 +34,8 @@ class BlazeFace:
                  nms_eta=1.0,
                  fixed_input_shape=None):
         self.backbone = backbone
+        self.mode = mode
         self.num_classes = num_classes
-        self.output_decoder = output_decoder
         self.min_sizes = min_sizes
         self.max_sizes = max_sizes
         self.steps = steps
@@ -130,24 +131,24 @@ class BlazeFace:
         inputs['image'] = fluid.data(
             dtype='float32', shape=[None, 3, None, None], name='image')
         if self.mode == 'train':
-            inputs['gt_box'] = fluid.data(
-                dtype='float32', shape=[None, None, 4], lod_level=1, name='gt_box')
+            inputs['gt_bbox'] = fluid.data(
+                dtype='float32', shape=[None, 4], lod_level=1, name='gt_bbox')
             inputs['gt_label'] = fluid.data(
-                dtype='int32', shape=[None, None], lod_level=1, name='gt_label')
-            inputs['im_size'] = fluid.data(
-                dtype='int32', shape=[None, 2], name='im_size')
+                dtype='int32', shape=[None, 1], lod_level=1, name='gt_label')
         elif self.mode == 'eval':
-            inputs['gt_box'] = fluid.data(
-                dtype='float32', shape=[None, None, 4], lod_level=1, name='gt_box')
+            inputs['gt_bbox'] = fluid.data(
+                dtype='float32', shape=[None, 4], lod_level=1, name='gt_bbox')
             inputs['gt_label'] = fluid.data(
-                dtype='int32', shape=[None, None], lod_level=1, name='gt_label')
+                dtype='int32', shape=[None, 1], lod_level=1, name='gt_label')
             inputs['is_difficult'] = fluid.data(
                 dtype='int32', shape=[None, 1], lod_level=1, name='is_difficult')
             inputs['im_id'] = fluid.data(
                 dtype='int32', shape=[None, 1], name='im_id')
+            inputs['im_shape'] = fluid.data(
+                dtype='int32', shape=[None, 2], name='im_shape')
         elif self.mode == 'test':
-            inputs['im_size'] = fluid.data(
-                dtype='int32', shape=[None, 2], name='im_size')
+            inputs['im_shape'] = fluid.data(
+                dtype='int32', shape=[None, 2], name='im_shape')
         return inputs
 
     def build_net(self, inputs):
@@ -156,22 +157,13 @@ class BlazeFace:
         if self.mode == 'train':
             gt_bbox = inputs['gt_bbox']
             gt_label = inputs['gt_label']
-            im_size = inputs['im_size']
-            num_boxes = fluid.layers.shape(gt_box)[1]
-            im_size_wh = fluid.layers.reverse(im_size, axis=1)
-            whwh = fluid.layers.concat([im_size_wh, im_size_wh], axis=1)
-            whwh = fluid.layers.unsqueeze(whwh, axes=[1])
-            whwh = fluid.layers.expand(whwh, expand_times=[1, num_boxes, 1])
-            whwh = fluid.layers.cast(whwh, dtype='float32')
-            whwh.stop_gradient = True
-            normalized_box = fluid.layers.elementwise_div(gt_box, whwh)
         body_feats = self.backbone(image)
         locs, confs, box, box_var = self._multi_box_head(
             inputs=body_feats,
             image=image,
             num_classes=self.num_classes,
             use_density_prior_box=self.use_density_prior_box)
-        if mode == 'train':
+        if self.mode == 'train':
             loss = fluid.layers.ssd_loss(
                 locs,
                 confs,
@@ -192,7 +184,7 @@ class BlazeFace:
                 box_var,
                 background_label=self.background_label,
                 nms_threshold=self.nms_threshold,
-                nms_top_k=self.nms_keep_topk,
+                nms_top_k=self.nms_topk,
                 keep_top_k=self.nms_keep_topk,
                 score_threshold=self.score_threshold,
                 nms_eta=self.nms_eta)
diff --git a/paddlex/cv/transforms/box_utils.py b/paddlex/cv/transforms/box_utils.py
index 23ef006..f9dc3da 100644
--- a/paddlex/cv/transforms/box_utils.py
+++ b/paddlex/cv/transforms/box_utils.py
@@ -459,4 +459,56 @@ def generate_sample_bbox_square(sampler, image_width, image_height):
     xmax = xmin + bbox_width
     ymax = ymin + bbox_height
     sampled_bbox = [xmin, ymin, xmax, ymax]
-    return sampled_bbox
\ No newline at end of file
+    return sampled_bbox
+
+
+def bbox_coverage(bbox1, bbox2):
+    inter_box = intersect_bbox(bbox1, bbox2)
+    intersect_size = bbox_area(inter_box)
+
+    if intersect_size > 0:
+        bbox1_size = bbox_area(bbox1)
+        return intersect_size / bbox1_size
+    else:
+        return 0.
+
+
+def meet_emit_constraint(src_bbox, sample_bbox):
+    center_x = (src_bbox[2] + src_bbox[0]) / 2
+    center_y = (src_bbox[3] + src_bbox[1]) / 2
+    if center_x >= sample_bbox[0] and \
+        center_x <= sample_bbox[2] and \
+        center_y >= sample_bbox[1] and \
+        center_y <= sample_bbox[3]:
+        return True
+    return False
+
+
+def is_overlap(object_bbox, sample_bbox):
+    if object_bbox[0] >= sample_bbox[2] or \
+       object_bbox[2] <= sample_bbox[0] or \
+       object_bbox[1] >= sample_bbox[3] or \
+       object_bbox[3] <= sample_bbox[1]:
+        return False
+    else:
+        return True
+
+
+def intersect_bbox(bbox1, bbox2):
+    if bbox2[0] > bbox1[2] or bbox2[2] < bbox1[0] or \
+        bbox2[1] > bbox1[3] or bbox2[3] < bbox1[1]:
+        intersection_box = [0.0, 0.0, 0.0, 0.0]
+    else:
+        intersection_box = [
+            max(bbox1[0], bbox2[0]), max(bbox1[1], bbox2[1]),
+            min(bbox1[2], bbox2[2]), min(bbox1[3], bbox2[3])
+        ]
+    return intersection_box
+
+
+def clip_bbox(src_bbox):
+    src_bbox[0] = max(min(src_bbox[0], 1.0), 0.0)
+    src_bbox[1] = max(min(src_bbox[1], 1.0), 0.0)
+    src_bbox[2] = max(min(src_bbox[2], 1.0), 0.0)
+    src_bbox[3] = max(min(src_bbox[3], 1.0), 0.0)
+    return src_bbox
\ No newline at end of file
diff --git a/paddlex/cv/transforms/det_transforms.py b/paddlex/cv/transforms/det_transforms.py
index 0f7c582..55ebb8d 100644
--- a/paddlex/cv/transforms/det_transforms.py
+++ b/paddlex/cv/transforms/det_transforms.py
@@ -1110,13 +1110,13 @@ class CropImageWithDataAchorSampling(DetTransform):
         gt_bbox = label_info['gt_bbox']
         gt_bbox_tmp = gt_bbox.copy()
         for i in range(gt_bbox_tmp.shape[0]):
-            gt_bbox_tmp[i][0] = gt_bbox[i][0] / im_width
-            gt_bbox_tmp[i][1] = gt_bbox[i][1] / im_height
-            gt_bbox_tmp[i][2] = gt_bbox[i][2] / im_width
-            gt_bbox_tmp[i][3] = gt_bbox[i][3] / im_height
+            gt_bbox_tmp[i][0] = gt_bbox[i][0] / image_width
+            gt_bbox_tmp[i][1] = gt_bbox[i][1] / image_height
+            gt_bbox_tmp[i][2] = gt_bbox[i][2] / image_width
+            gt_bbox_tmp[i][3] = gt_bbox[i][3] / image_height
         gt_class = label_info['gt_class']
         gt_score = None
-        if 'gt_score' in sample:
+        if 'gt_score' in label_info:
             gt_score = label_info['gt_score']
         sampled_bbox = []
         gt_bbox_tmp = gt_bbox_tmp.tolist()
@@ -1505,13 +1505,22 @@ class ArrangeBlazeFace(DetTransform):
                     'Becasuse the im_info and label_info can not be None!')
             if len(label_info['gt_bbox']) != len(label_info['gt_class']):
                 raise ValueError("gt num mismatch: bbox and class.")
-            outputs = (im, label_info['gt_bbox'], label_info['gt_class'], im_info['image_shape'])
+            gt_bbox = label_info['gt_bbox']
+            im_shape = im_info['image_shape']
+            im_height = im_shape[0]
+            im_width = im_shape[1]
+            for i in range(gt_bbox.shape[0]):
+                gt_bbox[i][0] = gt_bbox[i][0] / im_width
+                gt_bbox[i][1] = gt_bbox[i][1] / im_height
+                gt_bbox[i][2] = gt_bbox[i][2] / im_width
+                gt_bbox[i][3] = gt_bbox[i][3] / im_height
+            outputs = (im, gt_bbox, label_info['gt_class'])
         elif self.mode == 'eval':
            if im_info is None :
                 raise TypeError(
                     'Cannot do ArrangeBlazeFace! ' +
                     'Becasuse the im_info can not be None!')
-            gt_bbox = im_info['gt_bbox']
+            gt_bbox = label_info['gt_bbox']
             im_shape = im_info['image_shape']
             im_height = im_shape[0]
             im_width = im_shape[1]
@@ -1520,8 +1529,8 @@ class ArrangeBlazeFace(DetTransform):
                 gt_bbox[i][1] = gt_bbox[i][1] / im_height
                 gt_bbox[i][2] = gt_bbox[i][2] / im_width
                 gt_bbox[i][3] = gt_bbox[i][3] / im_height
-            outputs = (im, gt_bbox, im_info['gt_class'],
-                       im_info['difficult'], im_info['im_id'])
+            outputs = (im, gt_bbox, label_info['gt_class'],
+                       label_info['difficult'], im_info['im_id'], im_shape)
         else:
             if im_info is None:
                 raise TypeError('Cannot do ArrangeBlazeFace! ' +
diff --git a/paddlex/cv/transforms/ops.py b/paddlex/cv/transforms/ops.py
index dd517d4..6191875 100644
--- a/paddlex/cv/transforms/ops.py
+++ b/paddlex/cv/transforms/ops.py
@@ -18,8 +18,9 @@ import numpy as np
 from PIL import Image, ImageEnhance
 
 
-def normalize(im, mean, std):
-    im = im / 255.0
+def normalize(im, mean, std, is_scale=True):
+    if is_scale:
+        im = im / 255.0
 im -= mean
 im /= std
 return im
-- 
GitLab
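
Illustrative note (reference code only, not a file touched by this patch): the sketch below shows the bounding-box convention the change settles on. ArrangeBlazeFace now divides pixel-coordinate gt_bbox values by the image width and height, and the clip_bbox helper added to detection_eval.py clips a box either to [0, 1] or to a given (height, width). normalize_gt_bbox is a hypothetical vectorized rewrite of the per-box loop in det_transforms.py; clip_bbox is copied from the diff. Moving the normalization into the transform is what lets build_net drop the old in-graph im_size/whwh normalization that this patch removes.

import numpy as np


def normalize_gt_bbox(gt_bbox, im_shape):
    # Same arithmetic as the per-box loop in ArrangeBlazeFace: divide
    # [xmin, ymin, xmax, ymax] pixel coordinates by image width/height.
    im_height, im_width = im_shape[0], im_shape[1]
    gt_bbox = gt_bbox.astype('float32')
    gt_bbox[:, 0] /= im_width
    gt_bbox[:, 1] /= im_height
    gt_bbox[:, 2] /= im_width
    gt_bbox[:, 3] /= im_height
    return gt_bbox


def clip_bbox(bbox, im_size=None):
    # Copied from the new helper in detection_eval.py: clip to [0, 1] when
    # im_size is None, otherwise to the (height, width) of the image.
    h = 1. if im_size is None else im_size[0]
    w = 1. if im_size is None else im_size[1]
    xmin = max(min(bbox[0], w), 0.)
    ymin = max(min(bbox[1], h), 0.)
    xmax = max(min(bbox[2], w), 0.)
    ymax = max(min(bbox[3], h), 0.)
    return xmin, ymin, xmax, ymax


# Hypothetical example: a 640x480 (w x h) image with one box in pixels that
# intentionally extends past the bottom edge.
boxes = np.array([[10., 20., 300., 700.]])
normalized = normalize_gt_bbox(boxes, (480, 640))
print(normalized)                # approx. [[0.0156 0.0417 0.4688 1.4583]]
print(clip_bbox(normalized[0]))  # ymax is clipped back to 1.0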
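
A second sketch, condensing the coverage helpers added to paddlex/cv/transforms/box_utils.py into a runnable snippet with a small usage example. bbox_area here is a stand-in for the existing helper in that file (it is not part of this diff), and the example boxes are made up; boxes are [xmin, ymin, xmax, ymax] in normalized coordinates.

def bbox_area(bbox):
    # Stand-in for the existing helper in box_utils.py (not shown in this diff).
    if bbox[2] < bbox[0] or bbox[3] < bbox[1]:
        return 0.
    return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])


def intersect_bbox(bbox1, bbox2):
    # As added by the patch: the overlapping region, or a zero box if disjoint.
    if bbox2[0] > bbox1[2] or bbox2[2] < bbox1[0] or \
            bbox2[1] > bbox1[3] or bbox2[3] < bbox1[1]:
        return [0.0, 0.0, 0.0, 0.0]
    return [max(bbox1[0], bbox2[0]), max(bbox1[1], bbox2[1]),
            min(bbox1[2], bbox2[2]), min(bbox1[3], bbox2[3])]


def bbox_coverage(bbox1, bbox2):
    # Condensed from the patch: the fraction of bbox1's area that lies inside bbox2.
    intersect_size = bbox_area(intersect_bbox(bbox1, bbox2))
    return intersect_size / bbox_area(bbox1) if intersect_size > 0 else 0.


print(bbox_coverage([0.2, 0.2, 0.4, 0.4], [0.3, 0.0, 1.0, 1.0]))  # ~0.5: half of bbox1 is covered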