提交 7a614e76 编写于 作者: S sunyanfang01

fix the bug

上级 8090d633
...@@ -86,6 +86,7 @@ class WIDERFACEDetection(VOCDetection): ...@@ -86,6 +86,7 @@ class WIDERFACEDetection(VOCDetection):
continue continue
else: else:
is_discard = True is_discard = True
print(img_file)
im = cv2.imread(img_file) im = cv2.imread(img_file)
im_w = im.shape[1] im_w = im.shape[1]
im_h = im.shape[0] im_h = im.shape[0]
......
...@@ -22,6 +22,7 @@ import paddlex.utils.logging as logging ...@@ -22,6 +22,7 @@ import paddlex.utils.logging as logging
import paddlex import paddlex
from .base import BaseAPI from .base import BaseAPI
from collections import OrderedDict from collections import OrderedDict
from .utils.detection_eval import eval_results, bbox2out
import copy import copy
class BlazeFace(BaseAPI): class BlazeFace(BaseAPI):
...@@ -74,13 +75,14 @@ class BlazeFace(BaseAPI): ...@@ -74,13 +75,14 @@ class BlazeFace(BaseAPI):
def build_net(self, mode='train'): def build_net(self, mode='train'):
model = paddlex.cv.nets.detection.BlazeFace( model = paddlex.cv.nets.detection.BlazeFace(
backbone=self._get_backbone(self.backbone), backbone=self._get_backbone(self.backbone),
mode=mode,
min_sizes=self.min_sizes, min_sizes=self.min_sizes,
num_classes=self.num_classes, num_classes=self.num_classes,
use_density_prior_box=self.use_density_prior_box, use_density_prior_box=self.use_density_prior_box,
densities=self.densities, densities=self.densities,
nms_threshold=self.nms_iou_threshold, nms_threshold=self.nms_iou_threshold,
nms_topk=self.nms_topk, nms_topk=self.nms_topk,
nms_keep_topk=self.nms_score_threshold, nms_keep_topk=self.nms_keep_topk,
score_threshold=self.nms_score_threshold, score_threshold=self.nms_score_threshold,
fixed_input_shape=self.fixed_input_shape) fixed_input_shape=self.fixed_input_shape)
inputs = model.generate_inputs() inputs = model.generate_inputs()
...@@ -263,6 +265,8 @@ class BlazeFace(BaseAPI): ...@@ -263,6 +265,8 @@ class BlazeFace(BaseAPI):
} }
res_im_id = [d[4] for d in data] res_im_id = [d[4] for d in data]
res['im_id'] = (np.array(res_im_id), []) res['im_id'] = (np.array(res_im_id), [])
res_im_shape = [d[5] for d in data]
res['im_shape'] = (np.array(res_im_shape), [])
if metric == 'VOC': if metric == 'VOC':
res_gt_box = [] res_gt_box = []
res_gt_label = [] res_gt_label = []
......
...@@ -84,6 +84,16 @@ def eval_results(results, ...@@ -84,6 +84,16 @@ def eval_results(results,
return box_ap_stats, eval_details return box_ap_stats, eval_details
def clip_bbox(bbox, im_size=None):
    """Clamp a box's four coordinates into the image extent.

    Args:
        bbox: Indexable holding at least (xmin, ymin, xmax, ymax).
        im_size: Optional indexable read as (height, width). When None,
            both bounds default to 1., i.e. the unit square.

    Returns:
        tuple: (xmin, ymin, xmax, ymax) where x values are limited to
            [0, width] and y values to [0, height].
    """
    if im_size is None:
        height, width = 1., 1.
    else:
        height, width = im_size[0], im_size[1]
    # Upper bound for each coordinate, in the same order as the box.
    upper = (width, height, width, height)
    return tuple(max(min(bbox[i], upper[i]), 0.) for i in range(4))
def proposal_eval(results, coco_gt, outputfile, max_dets=(100, 300, 1000)): def proposal_eval(results, coco_gt, outputfile, max_dets=(100, 300, 1000)):
assert 'proposal' in results[0] assert 'proposal' in results[0]
assert outfile.endswith('.json') assert outfile.endswith('.json')
......
...@@ -24,6 +24,7 @@ from .xception import Xception ...@@ -24,6 +24,7 @@ from .xception import Xception
from .densenet import DenseNet from .densenet import DenseNet
from .shufflenet_v2 import ShuffleNetV2 from .shufflenet_v2 import ShuffleNetV2
from .hrnet import HRNet from .hrnet import HRNet
from .blazenet import BlazeNet
def resnet18(input, num_classes=1000): def resnet18(input, num_classes=1000):
......
...@@ -19,10 +19,6 @@ from __future__ import print_function ...@@ -19,10 +19,6 @@ from __future__ import print_function
from paddle import fluid from paddle import fluid
from paddle.fluid.param_attr import ParamAttr from paddle.fluid.param_attr import ParamAttr
from ppdet.experimental import mixed_precision_global_state
from ppdet.core.workspace import register
class BlazeNet(object): class BlazeNet(object):
""" """
...@@ -147,7 +143,6 @@ class BlazeNet(object): ...@@ -147,7 +143,6 @@ class BlazeNet(object):
use_pool = not stride == 1 use_pool = not stride == 1
use_double_block = double_channels is not None use_double_block = double_channels is not None
act = 'relu' if use_double_block else None act = 'relu' if use_double_block else None
mixed_precision_enabled = mixed_precision_global_state() is not None
if use_5x5kernel: if use_5x5kernel:
conv_dw = self._conv_norm( conv_dw = self._conv_norm(
...@@ -157,7 +152,7 @@ class BlazeNet(object): ...@@ -157,7 +152,7 @@ class BlazeNet(object):
stride=stride, stride=stride,
padding=2, padding=2,
num_groups=in_channels, num_groups=in_channels,
use_cudnn=mixed_precision_enabled, use_cudnn=True,
name=name + "1_dw") name=name + "1_dw")
else: else:
conv_dw_1 = self._conv_norm( conv_dw_1 = self._conv_norm(
...@@ -167,7 +162,7 @@ class BlazeNet(object): ...@@ -167,7 +162,7 @@ class BlazeNet(object):
stride=1, stride=1,
padding=1, padding=1,
num_groups=in_channels, num_groups=in_channels,
use_cudnn=mixed_precision_enabled, use_cudnn=True,
name=name + "1_dw_1") name=name + "1_dw_1")
conv_dw = self._conv_norm( conv_dw = self._conv_norm(
input=conv_dw_1, input=conv_dw_1,
...@@ -176,7 +171,7 @@ class BlazeNet(object): ...@@ -176,7 +171,7 @@ class BlazeNet(object):
stride=stride, stride=stride,
padding=1, padding=1,
num_groups=in_channels, num_groups=in_channels,
use_cudnn=mixed_precision_enabled, use_cudnn=True,
name=name + "1_dw_2") name=name + "1_dw_2")
conv_pw = self._conv_norm( conv_pw = self._conv_norm(
...@@ -196,7 +191,7 @@ class BlazeNet(object): ...@@ -196,7 +191,7 @@ class BlazeNet(object):
num_filters=out_channels, num_filters=out_channels,
stride=1, stride=1,
padding=2, padding=2,
use_cudnn=mixed_precision_enabled, use_cudnn=True,
name=name + "2_dw") name=name + "2_dw")
else: else:
conv_dw_1 = self._conv_norm( conv_dw_1 = self._conv_norm(
...@@ -206,7 +201,7 @@ class BlazeNet(object): ...@@ -206,7 +201,7 @@ class BlazeNet(object):
stride=1, stride=1,
padding=1, padding=1,
num_groups=out_channels, num_groups=out_channels,
use_cudnn=mixed_precision_enabled, use_cudnn=True,
name=name + "2_dw_1") name=name + "2_dw_1")
conv_dw = self._conv_norm( conv_dw = self._conv_norm(
input=conv_dw_1, input=conv_dw_1,
...@@ -215,7 +210,7 @@ class BlazeNet(object): ...@@ -215,7 +210,7 @@ class BlazeNet(object):
stride=1, stride=1,
padding=1, padding=1,
num_groups=out_channels, num_groups=out_channels,
use_cudnn=mixed_precision_enabled, use_cudnn=True,
name=name + "2_dw_2") name=name + "2_dw_2")
conv_pw = self._conv_norm( conv_pw = self._conv_norm(
......
...@@ -20,6 +20,7 @@ from collections import OrderedDict ...@@ -20,6 +20,7 @@ from collections import OrderedDict
class BlazeFace: class BlazeFace:
def __init__(self, def __init__(self,
backbone, backbone,
mode='train',
min_sizes=[[16., 24.], [32., 48., 64., 80., 96., 128.]], min_sizes=[[16., 24.], [32., 48., 64., 80., 96., 128.]],
max_sizes=None, max_sizes=None,
steps=[8., 16.], steps=[8., 16.],
...@@ -33,8 +34,8 @@ class BlazeFace: ...@@ -33,8 +34,8 @@ class BlazeFace:
nms_eta=1.0, nms_eta=1.0,
fixed_input_shape=None): fixed_input_shape=None):
self.backbone = backbone self.backbone = backbone
self.mode=mode
self.num_classes = num_classes self.num_classes = num_classes
self.output_decoder = output_decoder
self.min_sizes = min_sizes self.min_sizes = min_sizes
self.max_sizes = max_sizes self.max_sizes = max_sizes
self.steps = steps self.steps = steps
...@@ -130,24 +131,24 @@ class BlazeFace: ...@@ -130,24 +131,24 @@ class BlazeFace:
inputs['image'] = fluid.data( inputs['image'] = fluid.data(
dtype='float32', shape=[None, 3, None, None], name='image') dtype='float32', shape=[None, 3, None, None], name='image')
if self.mode == 'train': if self.mode == 'train':
inputs['gt_box'] = fluid.data( inputs['gt_bbox'] = fluid.data(
dtype='float32', shape=[None, None, 4], lod_level=1, name='gt_box') dtype='float32', shape=[None, 4], lod_level=1, name='gt_bbox')
inputs['gt_label'] = fluid.data( inputs['gt_label'] = fluid.data(
dtype='int32', shape=[None, None], lod_level=1, name='gt_label') dtype='int32', shape=[None, 1], lod_level=1, name='gt_label')
inputs['im_size'] = fluid.data(
dtype='int32', shape=[None, 2], name='im_size')
elif self.mode == 'eval': elif self.mode == 'eval':
inputs['gt_box'] = fluid.data( inputs['gt_bbox'] = fluid.data(
dtype='float32', shape=[None, None, 4], lod_level=1, name='gt_box') dtype='float32', shape=[None, 4], lod_level=1, name='gt_bbox')
inputs['gt_label'] = fluid.data( inputs['gt_label'] = fluid.data(
dtype='int32', shape=[None, None], lod_level=1, name='gt_label') dtype='int32', shape=[None, 1], lod_level=1, name='gt_label')
inputs['is_difficult'] = fluid.data( inputs['is_difficult'] = fluid.data(
dtype='int32', shape=[None, 1], lod_level=1, name='is_difficult') dtype='int32', shape=[None, 1], lod_level=1, name='is_difficult')
inputs['im_id'] = fluid.data( inputs['im_id'] = fluid.data(
dtype='int32', shape=[None, 1], name='im_id') dtype='int32', shape=[None, 1], name='im_id')
inputs['im_shape'] = fluid.data(
dtype='int32', shape=[None, 2], name='im_shape')
elif self.mode == 'test': elif self.mode == 'test':
inputs['im_size'] = fluid.data( inputs['im_shape'] = fluid.data(
dtype='int32', shape=[None, 2], name='im_size') dtype='int32', shape=[None, 2], name='im_shape')
return inputs return inputs
...@@ -156,22 +157,13 @@ class BlazeFace: ...@@ -156,22 +157,13 @@ class BlazeFace:
if self.mode == 'train': if self.mode == 'train':
gt_bbox = inputs['gt_bbox'] gt_bbox = inputs['gt_bbox']
gt_label = inputs['gt_label'] gt_label = inputs['gt_label']
im_size = inputs['im_size']
num_boxes = fluid.layers.shape(gt_box)[1]
im_size_wh = fluid.layers.reverse(im_size, axis=1)
whwh = fluid.layers.concat([im_size_wh, im_size_wh], axis=1)
whwh = fluid.layers.unsqueeze(whwh, axes=[1])
whwh = fluid.layers.expand(whwh, expand_times=[1, num_boxes, 1])
whwh = fluid.layers.cast(whwh, dtype='float32')
whwh.stop_gradient = True
normalized_box = fluid.layers.elementwise_div(gt_box, whwh)
body_feats = self.backbone(image) body_feats = self.backbone(image)
locs, confs, box, box_var = self._multi_box_head( locs, confs, box, box_var = self._multi_box_head(
inputs=body_feats, inputs=body_feats,
image=image, image=image,
num_classes=self.num_classes, num_classes=self.num_classes,
use_density_prior_box=self.use_density_prior_box) use_density_prior_box=self.use_density_prior_box)
if mode == 'train': if self.mode == 'train':
loss = fluid.layers.ssd_loss( loss = fluid.layers.ssd_loss(
locs, locs,
confs, confs,
...@@ -192,7 +184,7 @@ class BlazeFace: ...@@ -192,7 +184,7 @@ class BlazeFace:
box_var, box_var,
background_label=self.background_label, background_label=self.background_label,
nms_threshold=self.nms_threshold, nms_threshold=self.nms_threshold,
nms_top_k=self.nms_keep_topk, nms_top_k=self.nms_topk,
keep_top_k=self.nms_keep_topk, keep_top_k=self.nms_keep_topk,
score_threshold=self.score_threshold, score_threshold=self.score_threshold,
nms_eta=self.nms_eta) nms_eta=self.nms_eta)
......
...@@ -459,4 +459,56 @@ def generate_sample_bbox_square(sampler, image_width, image_height): ...@@ -459,4 +459,56 @@ def generate_sample_bbox_square(sampler, image_width, image_height):
xmax = xmin + bbox_width xmax = xmin + bbox_width
ymax = ymin + bbox_height ymax = ymin + bbox_height
sampled_bbox = [xmin, ymin, xmax, ymax] sampled_bbox = [xmin, ymin, xmax, ymax]
return sampled_bbox return sampled_bbox
\ No newline at end of file
def bbox_coverage(bbox1, bbox2):
    """Return the share of bbox1's area that its intersection with bbox2 takes.

    Args:
        bbox1: Reference box; its area is the denominator.
        bbox2: Box intersected with bbox1.

    Returns:
        The intersection area divided by bbox1's area, or 0. when the
        intersection area is not positive.
    """
    overlap_area = bbox_area(intersect_bbox(bbox1, bbox2))
    # bbox1's area is only computed when there is a positive overlap.
    return overlap_area / bbox_area(bbox1) if overlap_area > 0 else 0.
def meet_emit_constraint(src_bbox, sample_bbox):
    """Check whether the center of src_bbox falls inside sample_bbox.

    Args:
        src_bbox: Box as (xmin, ymin, xmax, ymax) whose center is tested.
        sample_bbox: Box as (xmin, ymin, xmax, ymax) acting as the region.

    Returns:
        bool: True when the center lies within sample_bbox, borders included.
    """
    mid_x = (src_bbox[2] + src_bbox[0]) / 2
    mid_y = (src_bbox[3] + src_bbox[1]) / 2
    inside_x = mid_x >= sample_bbox[0] and mid_x <= sample_bbox[2]
    inside_y = mid_y >= sample_bbox[1] and mid_y <= sample_bbox[3]
    return bool(inside_x and inside_y)
def is_overlap(object_bbox, sample_bbox):
    """Report whether two boxes share a region of non-zero extent.

    Args:
        object_bbox: Box as (xmin, ymin, xmax, ymax).
        sample_bbox: Box as (xmin, ymin, xmax, ymax).

    Returns:
        bool: False when the boxes are apart or merely touch along an edge,
            True otherwise.
    """
    apart_on_x = (object_bbox[0] >= sample_bbox[2]
                  or object_bbox[2] <= sample_bbox[0])
    apart_on_y = (object_bbox[1] >= sample_bbox[3]
                  or object_bbox[3] <= sample_bbox[1])
    return not (apart_on_x or apart_on_y)
def intersect_bbox(bbox1, bbox2):
    """Compute the intersection rectangle of two boxes.

    Args:
        bbox1: Box as (xmin, ymin, xmax, ymax).
        bbox2: Box as (xmin, ymin, xmax, ymax).

    Returns:
        list: [xmin, ymin, xmax, ymax] of the shared region, or
            [0.0, 0.0, 0.0, 0.0] when the boxes are strictly apart.
    """
    disjoint = (bbox2[0] > bbox1[2] or bbox2[2] < bbox1[0]
                or bbox2[1] > bbox1[3] or bbox2[3] < bbox1[1])
    if disjoint:
        return [0.0, 0.0, 0.0, 0.0]
    left = max(bbox1[0], bbox2[0])
    top = max(bbox1[1], bbox2[1])
    right = min(bbox1[2], bbox2[2])
    bottom = min(bbox1[3], bbox2[3])
    return [left, top, right, bottom]
def clip_bbox(src_bbox):
    """Clamp the first four entries of src_bbox to [0.0, 1.0] in place.

    Args:
        src_bbox: Mutable sequence holding (xmin, ymin, xmax, ymax).

    Returns:
        The same src_bbox object, after its coordinates were clamped.
    """
    for idx in range(4):
        src_bbox[idx] = max(min(src_bbox[idx], 1.0), 0.0)
    return src_bbox
\ No newline at end of file
...@@ -1110,13 +1110,13 @@ class CropImageWithDataAchorSampling(DetTransform): ...@@ -1110,13 +1110,13 @@ class CropImageWithDataAchorSampling(DetTransform):
gt_bbox = label_info['gt_bbox'] gt_bbox = label_info['gt_bbox']
gt_bbox_tmp = gt_bbox.copy() gt_bbox_tmp = gt_bbox.copy()
for i in range(gt_bbox_tmp.shape[0]): for i in range(gt_bbox_tmp.shape[0]):
gt_bbox_tmp[i][0] = gt_bbox[i][0] / im_width gt_bbox_tmp[i][0] = gt_bbox[i][0] / image_width
gt_bbox_tmp[i][1] = gt_bbox[i][1] / im_height gt_bbox_tmp[i][1] = gt_bbox[i][1] / image_height
gt_bbox_tmp[i][2] = gt_bbox[i][2] / im_width gt_bbox_tmp[i][2] = gt_bbox[i][2] / image_width
gt_bbox_tmp[i][3] = gt_bbox[i][3] / im_height gt_bbox_tmp[i][3] = gt_bbox[i][3] / image_height
gt_class = label_info['gt_class'] gt_class = label_info['gt_class']
gt_score = None gt_score = None
if 'gt_score' in sample: if 'gt_score' in label_info:
gt_score = label_info['gt_score'] gt_score = label_info['gt_score']
sampled_bbox = [] sampled_bbox = []
gt_bbox_tmp = gt_bbox_tmp.tolist() gt_bbox_tmp = gt_bbox_tmp.tolist()
...@@ -1505,13 +1505,22 @@ class ArrangeBlazeFace(DetTransform): ...@@ -1505,13 +1505,22 @@ class ArrangeBlazeFace(DetTransform):
'Becasuse the im_info and label_info can not be None!') 'Becasuse the im_info and label_info can not be None!')
if len(label_info['gt_bbox']) != len(label_info['gt_class']): if len(label_info['gt_bbox']) != len(label_info['gt_class']):
raise ValueError("gt num mismatch: bbox and class.") raise ValueError("gt num mismatch: bbox and class.")
outputs = (im, label_info['gt_bbox'], label_info['gt_class'], im_info['image_shape']) gt_bbox = label_info['gt_bbox']
im_shape = im_info['image_shape']
im_height = im_shape[0]
im_width = im_shape[1]
for i in range(gt_bbox.shape[0]):
gt_bbox[i][0] = gt_bbox[i][0] / im_width
gt_bbox[i][1] = gt_bbox[i][1] / im_height
gt_bbox[i][2] = gt_bbox[i][2] / im_width
gt_bbox[i][3] = gt_bbox[i][3] / im_height
outputs = (im, gt_bbox, label_info['gt_class'])
elif self.mode == 'eval': elif self.mode == 'eval':
if im_info is None : if im_info is None :
raise TypeError( raise TypeError(
'Cannot do ArrangeBlazeFace! ' + 'Cannot do ArrangeBlazeFace! ' +
'Becasuse the im_info can not be None!') 'Becasuse the im_info can not be None!')
gt_bbox = im_info['gt_bbox'] gt_bbox = label_info['gt_bbox']
im_shape = im_info['image_shape'] im_shape = im_info['image_shape']
im_height = im_shape[0] im_height = im_shape[0]
im_width = im_shape[1] im_width = im_shape[1]
...@@ -1520,8 +1529,8 @@ class ArrangeBlazeFace(DetTransform): ...@@ -1520,8 +1529,8 @@ class ArrangeBlazeFace(DetTransform):
gt_bbox[i][1] = gt_bbox[i][1] / im_height gt_bbox[i][1] = gt_bbox[i][1] / im_height
gt_bbox[i][2] = gt_bbox[i][2] / im_width gt_bbox[i][2] = gt_bbox[i][2] / im_width
gt_bbox[i][3] = gt_bbox[i][3] / im_height gt_bbox[i][3] = gt_bbox[i][3] / im_height
outputs = (im, gt_bbox, im_info['gt_class'], outputs = (im, gt_bbox, label_info['gt_class'],
im_info['difficult'], im_info['im_id']) label_info['difficult'], im_info['im_id'], im_shape)
else: else:
if im_info is None: if im_info is None:
raise TypeError('Cannot do ArrangeBlazeFace! ' + raise TypeError('Cannot do ArrangeBlazeFace! ' +
......
...@@ -18,8 +18,9 @@ import numpy as np ...@@ -18,8 +18,9 @@ import numpy as np
from PIL import Image, ImageEnhance from PIL import Image, ImageEnhance
def normalize(im, mean, std, is_scale=True):
    """Standardize an image: optionally scale to [0, 1], subtract mean, divide by std.

    Args:
        im: Image array supporting elementwise arithmetic (e.g. a numpy array).
        mean: Value(s) subtracted from the image, broadcast against `im`.
        std: Value(s) the image is divided by, broadcast against `im`.
        is_scale (bool): When True, pixel values are first divided by 255.0.

    Returns:
        A new floating-point array holding the normalized image; the input
        array is left untouched.
    """
    # Always divide by a scale so that a fresh floating-point array is
    # created on both paths. Without this, is_scale=False would apply the
    # in-place operations below directly to the caller's array (mutating it,
    # or failing with a casting error for integer images).
    scale = 255.0 if is_scale else 1.0
    im = im / scale
    im -= mean
    im /= std
    return im
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册