未验证 提交 01d57c6a 编写于 作者: G Guanghua Yu 提交者: GitHub

fix RCNN dygraph to static (#2184)

* fix RCNN dygraph to static
上级 9a4fae6d
...@@ -84,16 +84,8 @@ class Detector(object): ...@@ -84,16 +84,8 @@ class Detector(object):
np_boxes[:, 3] *= w np_boxes[:, 3] *= w
np_boxes[:, 4] *= h np_boxes[:, 4] *= h
np_boxes[:, 5] *= w np_boxes[:, 5] *= w
expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
np_boxes = np_boxes[expect_boxes, :]
for box in np_boxes:
print('class_id:{:d}, confidence:{:.4f},'
'left_top:[{:.2f},{:.2f}],'
' right_bottom:[{:.2f},{:.2f}]'.format(
int(box[0]), box[1], box[2], box[3], box[4], box[5]))
results['boxes'] = np_boxes results['boxes'] = np_boxes
if np_masks is not None: if np_masks is not None:
np_masks = np_masks[expect_boxes, :, :, :]
results['masks'] = np_masks results['masks'] = np_masks
return results return results
...@@ -111,7 +103,7 @@ class Detector(object): ...@@ -111,7 +103,7 @@ class Detector(object):
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box, results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max] matix element:[class, score, x_min, y_min, x_max, y_max]
MaskRCNN's results include 'masks': np.ndarray: MaskRCNN's results include 'masks': np.ndarray:
shape:[N, class_num, mask_resolution, mask_resolution] shape: [N, im_h, im_w]
''' '''
inputs = self.preprocess(image) inputs = self.preprocess(image)
np_boxes, np_masks = None, None np_boxes, np_masks = None, None
...@@ -125,7 +117,7 @@ class Detector(object): ...@@ -125,7 +117,7 @@ class Detector(object):
output_names = self.predictor.get_output_names() output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_handle(output_names[0]) boxes_tensor = self.predictor.get_output_handle(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu() np_boxes = boxes_tensor.copy_to_cpu()
if self.pred_config.mask_resolution is not None: if self.pred_config.mask:
masks_tensor = self.predictor.get_output_handle(output_names[2]) masks_tensor = self.predictor.get_output_handle(output_names[2])
np_masks = masks_tensor.copy_to_cpu() np_masks = masks_tensor.copy_to_cpu()
...@@ -135,14 +127,7 @@ class Detector(object): ...@@ -135,14 +127,7 @@ class Detector(object):
output_names = self.predictor.get_output_names() output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_handle(output_names[0]) boxes_tensor = self.predictor.get_output_handle(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu() np_boxes = boxes_tensor.copy_to_cpu()
score_tensor = self.predictor.get_output_handle(output_names[3]) if self.pred_config.mask:
np_score = score_tensor.copy_to_cpu()
label_tensor = self.predictor.get_output_handle(output_names[2])
np_label = label_tensor.copy_to_cpu()
np_boxes = np.concatenate(
[np_label[:, np.newaxis], np_score[:, np.newaxis], np_boxes],
axis=-1)
if self.pred_config.mask_resolution is not None:
masks_tensor = self.predictor.get_output_handle(output_names[2]) masks_tensor = self.predictor.get_output_handle(output_names[2])
np_masks = masks_tensor.copy_to_cpu() np_masks = masks_tensor.copy_to_cpu()
t2 = time.time() t2 = time.time()
...@@ -196,10 +181,9 @@ class DetectorSOLOv2(Detector): ...@@ -196,10 +181,9 @@ class DetectorSOLOv2(Detector):
image (str/np.ndarray): path of image/ np.ndarray read by cv2 image (str/np.ndarray): path of image/ np.ndarray read by cv2
threshold (float): threshold of predicted box' score threshold (float): threshold of predicted box' score
Returns: Returns:
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box, results (dict): 'segm': np.ndarray,shape:[N, im_h, im_w]
matix element:[class, score, x_min, y_min, x_max, y_max] 'cate_label': label of segm, shape:[N]
MaskRCNN's results include 'masks': np.ndarray: 'cate_score': confidence score of segm, shape:[N]
shape:[N, class_num, mask_resolution, mask_resolution]
''' '''
inputs = self.preprocess(image) inputs = self.preprocess(image)
np_label, np_score, np_segms = None, None, None np_label, np_score, np_segms = None, None, None
...@@ -273,9 +257,9 @@ class PredictConfig(): ...@@ -273,9 +257,9 @@ class PredictConfig():
self.preprocess_infos = yml_conf['Preprocess'] self.preprocess_infos = yml_conf['Preprocess']
self.min_subgraph_size = yml_conf['min_subgraph_size'] self.min_subgraph_size = yml_conf['min_subgraph_size']
self.labels = yml_conf['label_list'] self.labels = yml_conf['label_list']
self.mask_resolution = None self.mask = False
if 'mask_resolution' in yml_conf: if 'mask' in yml_conf:
self.mask_resolution = yml_conf['mask_resolution'] self.mask = yml_conf['mask']
self.input_shape = yml_conf['image_shape'] self.input_shape = yml_conf['image_shape']
self.print_config() self.print_config()
...@@ -355,19 +339,9 @@ def load_predictor(model_dir, ...@@ -355,19 +339,9 @@ def load_predictor(model_dir,
return predictor return predictor
def visualize(image_file, def visualize(image_file, results, labels, output_dir='output/', threshold=0.5):
results,
labels,
mask_resolution=14,
output_dir='output/',
threshold=0.5):
# visualize the predict result # visualize the predict result
im = visualize_box_mask( im = visualize_box_mask(image_file, results, labels, threshold=threshold)
image_file,
results,
labels,
mask_resolution=mask_resolution,
threshold=threshold)
img_name = os.path.split(image_file)[-1] img_name = os.path.split(image_file)[-1]
if not os.path.exists(output_dir): if not os.path.exists(output_dir):
os.makedirs(output_dir) os.makedirs(output_dir)
...@@ -397,7 +371,6 @@ def predict_image(detector): ...@@ -397,7 +371,6 @@ def predict_image(detector):
FLAGS.image_file, FLAGS.image_file,
results, results,
detector.pred_config.labels, detector.pred_config.labels,
mask_resolution=detector.pred_config.mask_resolution,
output_dir=FLAGS.output_dir, output_dir=FLAGS.output_dir,
threshold=FLAGS.threshold) threshold=FLAGS.threshold)
...@@ -431,7 +404,6 @@ def predict_video(detector, camera_id): ...@@ -431,7 +404,6 @@ def predict_video(detector, camera_id):
frame, frame,
results, results,
detector.pred_config.labels, detector.pred_config.labels,
mask_resolution=detector.pred_config.mask_resolution,
threshold=FLAGS.threshold) threshold=FLAGS.threshold)
im = np.array(im) im = np.array(im)
writer.write(im) writer.write(im)
......
...@@ -21,16 +21,15 @@ from PIL import Image, ImageDraw ...@@ -21,16 +21,15 @@ from PIL import Image, ImageDraw
from scipy import ndimage from scipy import ndimage
def visualize_box_mask(im, results, labels, mask_resolution=14, threshold=0.5): def visualize_box_mask(im, results, labels, threshold=0.5):
""" """
Args: Args:
im (str/np.ndarray): path of image/np.ndarray read by cv2 im (str/np.ndarray): path of image/np.ndarray read by cv2
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box, results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max] matix element:[class, score, x_min, y_min, x_max, y_max]
MaskRCNN's results include 'masks': np.ndarray: MaskRCNN's results include 'masks': np.ndarray:
shape:[N, class_num, mask_resolution, mask_resolution] shape:[N, im_h, im_w]
labels (list): labels:['class1', ..., 'classn'] labels (list): labels:['class1', ..., 'classn']
mask_resolution (int): shape of a mask is:[mask_resolution, mask_resolution]
threshold (float): Threshold of score. threshold (float): Threshold of score.
Returns: Returns:
im (PIL.Image.Image): visualized image im (PIL.Image.Image): visualized image
...@@ -41,13 +40,9 @@ def visualize_box_mask(im, results, labels, mask_resolution=14, threshold=0.5): ...@@ -41,13 +40,9 @@ def visualize_box_mask(im, results, labels, mask_resolution=14, threshold=0.5):
im = Image.fromarray(im) im = Image.fromarray(im)
if 'masks' in results and 'boxes' in results: if 'masks' in results and 'boxes' in results:
im = draw_mask( im = draw_mask(
im, im, results['boxes'], results['masks'], labels, threshold=threshold)
results['boxes'],
results['masks'],
labels,
resolution=mask_resolution)
if 'boxes' in results: if 'boxes' in results:
im = draw_box(im, results['boxes'], labels) im = draw_box(im, results['boxes'], labels, threshold=threshold)
if 'segm' in results: if 'segm' in results:
im = draw_segm( im = draw_segm(
im, im,
...@@ -80,91 +75,49 @@ def get_color_map_list(num_classes): ...@@ -80,91 +75,49 @@ def get_color_map_list(num_classes):
return color_map return color_map
def expand_boxes(boxes, scale=0.0): def draw_mask(im, np_boxes, np_masks, labels, threshold=0.5):
"""
Args:
boxes (np.ndarray): shape:[N,4], N:number of box,
matix element:[x_min, y_min, x_max, y_max]
scale (float): scale of boxes
Returns:
boxes_exp (np.ndarray): expanded boxes
"""
w_half = (boxes[:, 2] - boxes[:, 0]) * .5
h_half = (boxes[:, 3] - boxes[:, 1]) * .5
x_c = (boxes[:, 2] + boxes[:, 0]) * .5
y_c = (boxes[:, 3] + boxes[:, 1]) * .5
w_half *= scale
h_half *= scale
boxes_exp = np.zeros(boxes.shape)
boxes_exp[:, 0] = x_c - w_half
boxes_exp[:, 2] = x_c + w_half
boxes_exp[:, 1] = y_c - h_half
boxes_exp[:, 3] = y_c + h_half
return boxes_exp
def draw_mask(im, np_boxes, np_masks, labels, resolution=14, threshold=0.5):
""" """
Args: Args:
im (PIL.Image.Image): PIL image im (PIL.Image.Image): PIL image
np_boxes (np.ndarray): shape:[N,6], N: number of box, np_boxes (np.ndarray): shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max] matix element:[class, score, x_min, y_min, x_max, y_max]
np_masks (np.ndarray): shape:[N, class_num, resolution, resolution] np_masks (np.ndarray): shape:[N, im_h, im_w]
labels (list): labels:['class1', ..., 'classn'] labels (list): labels:['class1', ..., 'classn']
resolution (int): shape of a mask is:[resolution, resolution]
threshold (float): threshold of mask threshold (float): threshold of mask
Returns: Returns:
im (PIL.Image.Image): visualized image im (PIL.Image.Image): visualized image
""" """
color_list = get_color_map_list(len(labels)) color_list = get_color_map_list(len(labels))
scale = (resolution + 2.0) / resolution
im_w, im_h = im.size
w_ratio = 0.4 w_ratio = 0.4
alpha = 0.7 alpha = 0.7
im = np.array(im).astype('float32') im = np.array(im).astype('float32')
rects = np_boxes[:, 2:]
expand_rects = expand_boxes(rects, scale)
expand_rects = expand_rects.astype(np.int32)
clsid_scores = np_boxes[:, 0:2]
padded_mask = np.zeros((resolution + 2, resolution + 2), dtype=np.float32)
clsid2color = {} clsid2color = {}
for idx in range(len(np_boxes)): expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
clsid, score = clsid_scores[idx].tolist() np_boxes = np_boxes[expect_boxes, :]
clsid = int(clsid) np_masks = np_masks[expect_boxes, :, :]
xmin, ymin, xmax, ymax = expand_rects[idx].tolist() for i in range(len(np_masks)):
w = xmax - xmin + 1 clsid, score = int(np_boxes[i][0]), np_boxes[i][1]
h = ymax - ymin + 1 mask = np_masks[i]
w = np.maximum(w, 1)
h = np.maximum(h, 1)
padded_mask[1:-1, 1:-1] = np_masks[idx, int(clsid), :, :]
resized_mask = cv2.resize(padded_mask, (w, h))
resized_mask = np.array(resized_mask > threshold, dtype=np.uint8)
x0 = min(max(xmin, 0), im_w)
x1 = min(max(xmax + 1, 0), im_w)
y0 = min(max(ymin, 0), im_h)
y1 = min(max(ymax + 1, 0), im_h)
im_mask = np.zeros((im_h, im_w), dtype=np.uint8)
im_mask[y0:y1, x0:x1] = resized_mask[(y0 - ymin):(y1 - ymin), (
x0 - xmin):(x1 - xmin)]
if clsid not in clsid2color: if clsid not in clsid2color:
clsid2color[clsid] = color_list[clsid] clsid2color[clsid] = color_list[clsid]
color_mask = clsid2color[clsid] color_mask = clsid2color[clsid]
for c in range(3): for c in range(3):
color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255 color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
idx = np.nonzero(im_mask) idx = np.nonzero(mask)
color_mask = np.array(color_mask) color_mask = np.array(color_mask)
im[idx[0], idx[1], :] *= 1.0 - alpha im[idx[0], idx[1], :] *= 1.0 - alpha
im[idx[0], idx[1], :] += alpha * color_mask im[idx[0], idx[1], :] += alpha * color_mask
return Image.fromarray(im.astype('uint8')) return Image.fromarray(im.astype('uint8'))
def draw_box(im, np_boxes, labels): def draw_box(im, np_boxes, labels, threshold=0.5):
""" """
Args: Args:
im (PIL.Image.Image): PIL image im (PIL.Image.Image): PIL image
np_boxes (np.ndarray): shape:[N,6], N: number of box, np_boxes (np.ndarray): shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max] matix element:[class, score, x_min, y_min, x_max, y_max]
labels (list): labels:['class1', ..., 'classn'] labels (list): labels:['class1', ..., 'classn']
threshold (float): threshold of box
Returns: Returns:
im (PIL.Image.Image): visualized image im (PIL.Image.Image): visualized image
""" """
...@@ -172,10 +125,15 @@ def draw_box(im, np_boxes, labels): ...@@ -172,10 +125,15 @@ def draw_box(im, np_boxes, labels):
draw = ImageDraw.Draw(im) draw = ImageDraw.Draw(im)
clsid2color = {} clsid2color = {}
color_list = get_color_map_list(len(labels)) color_list = get_color_map_list(len(labels))
expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
np_boxes = np_boxes[expect_boxes, :]
for dt in np_boxes: for dt in np_boxes:
clsid, bbox, score = int(dt[0]), dt[2:], dt[1] clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
xmin, ymin, xmax, ymax = bbox xmin, ymin, xmax, ymax = bbox
print('class_id:{:d}, confidence:{:.4f}, left_top:[{:.2f},{:.2f}],'
'right_bottom:[{:.2f},{:.2f}]'.format(
int(clsid), score, xmin, ymin, xmax, ymax))
w = xmax - xmin w = xmax - xmin
h = ymax - ymin h = ymax - ymin
if clsid not in clsid2color: if clsid not in clsid2color:
......
...@@ -98,9 +98,8 @@ def _dump_infer_config(config, path, image_shape, model): ...@@ -98,9 +98,8 @@ def _dump_infer_config(config, path, image_shape, model):
'Architecture: {} is not supported for exporting model now'.format( 'Architecture: {} is not supported for exporting model now'.format(
infer_arch)) infer_arch))
os._exit(0) os._exit(0)
if 'mask_post_process' in model.__dict__ and model.__dict__[ if 'Mask' in infer_arch:
'mask_post_process']: infer_cfg['mask'] = True
infer_cfg['mask_resolution'] = model.mask_post_process.mask_resolution
infer_cfg['Preprocess'], infer_cfg[ infer_cfg['Preprocess'], infer_cfg[
'label_list'], image_shape = _parse_reader( 'label_list'], image_shape = _parse_reader(
config['TestReader'], config['TestDataset'], config['metric'], config['TestReader'], config['TestDataset'], config['metric'],
......
...@@ -30,7 +30,7 @@ def get_infer_results(outs, catid, bias=0): ...@@ -30,7 +30,7 @@ def get_infer_results(outs, catid, bias=0):
The output format is dictionary containing bbox or mask result. The output format is dictionary containing bbox or mask result.
For example, bbox result is a list and each element contains For example, bbox result is a list and each element contains
image_id, category_id, bbox and score. image_id, category_id, bbox and score.
""" """
if outs is None or len(outs) == 0: if outs is None or len(outs) == 0:
raise ValueError( raise ValueError(
...@@ -42,19 +42,12 @@ def get_infer_results(outs, catid, bias=0): ...@@ -42,19 +42,12 @@ def get_infer_results(outs, catid, bias=0):
infer_res = {} infer_res = {}
if 'bbox' in outs: if 'bbox' in outs:
infer_res['bbox'] = get_det_res( infer_res['bbox'] = get_det_res(
outs['bbox'], outs['bbox'], outs['bbox_num'], im_id, catid, bias=bias)
outs['score'],
outs['label'],
outs['bbox_num'],
im_id,
catid,
bias=bias)
if 'mask' in outs: if 'mask' in outs:
# mask post process # mask post process
infer_res['mask'] = get_seg_res(outs['mask'], outs['score'], infer_res['mask'] = get_seg_res(outs['mask'], outs['bbox'],
outs['label'], outs['bbox_num'], im_id, outs['bbox_num'], im_id, catid)
catid)
if 'segm' in outs: if 'segm' in outs:
infer_res['segm'] = get_solov2_segm_res(outs, im_id, catid) infer_res['segm'] = get_solov2_segm_res(outs, im_id, catid)
......
...@@ -99,13 +99,5 @@ class FasterRCNN(BaseArch): ...@@ -99,13 +99,5 @@ class FasterRCNN(BaseArch):
def get_pred(self): def get_pred(self):
bbox_pred, bbox_num = self._forward() bbox_pred, bbox_num = self._forward()
label = bbox_pred[:, 0] output = {'bbox': bbox_pred, 'bbox_num': bbox_num}
score = bbox_pred[:, 1]
bbox = bbox_pred[:, 2:]
output = {
'bbox': bbox,
'score': score,
'label': label,
'bbox_num': bbox_num
}
return output return output
...@@ -91,13 +91,5 @@ class FCOS(BaseArch): ...@@ -91,13 +91,5 @@ class FCOS(BaseArch):
def get_pred(self): def get_pred(self):
bboxes, bbox_num = self._forward() bboxes, bbox_num = self._forward()
label = bboxes[:, 0] output = {'bbox': bboxes, 'bbox_num': bbox_num}
score = bboxes[:, 1]
bbox = bboxes[:, 2:]
output = {
'bbox': bbox,
'score': score,
'label': label,
'bbox_num': bbox_num
}
return output return output
...@@ -124,14 +124,5 @@ class MaskRCNN(BaseArch): ...@@ -124,14 +124,5 @@ class MaskRCNN(BaseArch):
def get_pred(self): def get_pred(self):
bbox_pred, bbox_num, mask_pred = self._forward() bbox_pred, bbox_num, mask_pred = self._forward()
label = bbox_pred[:, 0] output = {'bbox': bbox_pred, 'bbox_num': bbox_num, 'mask': mask_pred}
score = bbox_pred[:, 1]
bbox = bbox_pred[:, 2:]
output = {
'label': label,
'score': score,
'bbox': bbox,
'bbox_num': bbox_num,
'mask': mask_pred,
}
return output return output
...@@ -91,13 +91,8 @@ class TTFNet(BaseArch): ...@@ -91,13 +91,8 @@ class TTFNet(BaseArch):
def get_pred(self): def get_pred(self):
bbox_pred, bbox_num = self._forward() bbox_pred, bbox_num = self._forward()
label = bbox_pred[:, 0]
score = bbox_pred[:, 1]
bbox = bbox_pred[:, 2:]
output = { output = {
"bbox": bbox, "bbox": bbox_pred,
'score': score,
'label': label,
"bbox_num": bbox_num, "bbox_num": bbox_num,
} }
return output return output
...@@ -61,13 +61,5 @@ class YOLOv3(BaseArch): ...@@ -61,13 +61,5 @@ class YOLOv3(BaseArch):
def get_pred(self): def get_pred(self):
bbox_pred, bbox_num = self._forward() bbox_pred, bbox_num = self._forward()
label = bbox_pred[:, 0] output = {'bbox': bbox_pred, 'bbox_num': bbox_num}
score = bbox_pred[:, 1]
bbox = bbox_pred[:, 2:]
output = {
'bbox': bbox,
'score': score,
'label': label,
'bbox_num': bbox_num
}
return output return output
...@@ -39,8 +39,6 @@ def bbox2delta(src_boxes, tgt_boxes, weights): ...@@ -39,8 +39,6 @@ def bbox2delta(src_boxes, tgt_boxes, weights):
def delta2bbox(deltas, boxes, weights): def delta2bbox(deltas, boxes, weights):
clip_scale = math.log(1000.0 / 16) clip_scale = math.log(1000.0 / 16)
if boxes.shape[0] == 0:
return paddle.zeros((0, deltas.shape[1]), dtype='float32')
widths = boxes[:, 2] - boxes[:, 0] widths = boxes[:, 2] - boxes[:, 0]
heights = boxes[:, 3] - boxes[:, 1] heights = boxes[:, 3] - boxes[:, 1]
...@@ -61,12 +59,13 @@ def delta2bbox(deltas, boxes, weights): ...@@ -61,12 +59,13 @@ def delta2bbox(deltas, boxes, weights):
pred_w = paddle.exp(dw) * widths.unsqueeze(1) pred_w = paddle.exp(dw) * widths.unsqueeze(1)
pred_h = paddle.exp(dh) * heights.unsqueeze(1) pred_h = paddle.exp(dh) * heights.unsqueeze(1)
pred_boxes = paddle.zeros_like(deltas) pred_boxes = []
pred_boxes.append(pred_ctr_x - 0.5 * pred_w)
pred_boxes.append(pred_ctr_y - 0.5 * pred_h)
pred_boxes.append(pred_ctr_x + 0.5 * pred_w)
pred_boxes.append(pred_ctr_y + 0.5 * pred_h)
pred_boxes = paddle.stack(pred_boxes, axis=-1)
pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w
pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h
pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w
pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h
return pred_boxes return pred_boxes
......
...@@ -141,8 +141,7 @@ class BBoxHead(nn.Layer): ...@@ -141,8 +141,7 @@ class BBoxHead(nn.Layer):
rois_feat = self.roi_extractor(body_feats, rois, rois_num) rois_feat = self.roi_extractor(body_feats, rois, rois_num)
bbox_feat = self.head(rois_feat) bbox_feat = self.head(rois_feat)
#if self.with_pool: if self.with_pool:
if len(bbox_feat.shape) > 2 and bbox_feat.shape[-1] > 1:
feat = F.adaptive_avg_pool2d(bbox_feat, output_size=1) feat = F.adaptive_avg_pool2d(bbox_feat, output_size=1)
feat = paddle.squeeze(feat, axis=[2, 3]) feat = paddle.squeeze(feat, axis=[2, 3])
else: else:
......
...@@ -182,11 +182,12 @@ class MaskHead(nn.Layer): ...@@ -182,11 +182,12 @@ class MaskHead(nn.Layer):
mask_out = F.sigmoid(mask_logit) mask_out = F.sigmoid(mask_logit)
else: else:
num_masks = mask_logit.shape[0] num_masks = mask_logit.shape[0]
pred_masks = paddle.split(mask_logit, num_masks)
mask_out = [] mask_out = []
# TODO: need to optimize gather # TODO: need to optimize gather
for i, pred_mask in enumerate(pred_masks): for i in range(mask_logit.shape[0]):
mask = paddle.gather(pred_mask, labels[i], axis=1) pred_masks = paddle.unsqueeze(
mask_logit[i, :, :, :], axis=0)
mask = paddle.gather(pred_masks, labels[i], axis=1)
mask_out.append(mask) mask_out.append(mask)
mask_out = F.sigmoid(paddle.concat(mask_out)) mask_out = F.sigmoid(paddle.concat(mask_out))
return mask_out return mask_out
......
...@@ -316,14 +316,12 @@ class RCNNBox(object): ...@@ -316,14 +316,12 @@ class RCNNBox(object):
# [N, C*4] # [N, C*4]
bbox = paddle.concat(roi) bbox = paddle.concat(roi)
bbox = delta2bbox(bbox_pred, bbox, self.prior_box_var) if bbox.shape[0] == 0:
bbox = paddle.zeros([0, bbox_pred.shape[1]], dtype='float32')
else:
bbox = delta2bbox(bbox_pred, bbox, self.prior_box_var)
scores = cls_prob[:, :-1] scores = cls_prob[:, :-1]
# [N*C, 4]
bbox_num_class = bbox.shape[1] // 4
bbox = paddle.reshape(bbox, [-1, bbox_num_class, 4])
origin_h = paddle.unsqueeze(origin_shape[:, 0], axis=1) origin_h = paddle.unsqueeze(origin_shape[:, 0], axis=1)
origin_w = paddle.unsqueeze(origin_shape[:, 1], axis=1) origin_w = paddle.unsqueeze(origin_shape[:, 1], axis=1)
zeros = paddle.zeros_like(origin_h) zeros = paddle.zeros_like(origin_h)
......
...@@ -54,8 +54,6 @@ class BBoxPostProcess(object): ...@@ -54,8 +54,6 @@ class BBoxPostProcess(object):
including labels, scores and bboxes. The size of including labels, scores and bboxes. The size of
bboxes are corresponding to the original image. bboxes are corresponding to the original image.
""" """
if bboxes.shape[0] == 0:
return paddle.zeros(shape=[1, 6])
origin_shape = paddle.floor(im_shape / scale_factor + 0.5) origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
...@@ -65,9 +63,12 @@ class BBoxPostProcess(object): ...@@ -65,9 +63,12 @@ class BBoxPostProcess(object):
for i in range(bbox_num.shape[0]): for i in range(bbox_num.shape[0]):
expand_shape = paddle.expand(origin_shape[i:i + 1, :], expand_shape = paddle.expand(origin_shape[i:i + 1, :],
[bbox_num[i], 2]) [bbox_num[i], 2])
scale_y, scale_x = scale_factor[i] scale_y, scale_x = scale_factor[i][0], scale_factor[i][1]
scale = paddle.concat([scale_x, scale_y, scale_x, scale_y]) scale = paddle.concat([scale_x, scale_y, scale_x, scale_y])
expand_scale = paddle.expand(scale, [bbox_num[i], 4]) expand_scale = paddle.expand(scale, [bbox_num[i], 4])
# TODO: Because paddle.expand transform error when dygraph
# to static, use reshape to avoid mistakes.
expand_scale = paddle.reshape(expand_scale, [bbox_num[i], 4])
origin_shape_list.append(expand_shape) origin_shape_list.append(expand_shape)
scale_factor_list.append(expand_scale) scale_factor_list.append(expand_scale)
...@@ -121,6 +122,10 @@ class MaskPostProcess(object): ...@@ -121,6 +122,10 @@ class MaskPostProcess(object):
gx = paddle.expand(img_x, [N, img_y.shape[1], img_x.shape[2]]) gx = paddle.expand(img_x, [N, img_y.shape[1], img_x.shape[2]])
gy = paddle.expand(img_y, [N, img_y.shape[1], img_x.shape[2]]) gy = paddle.expand(img_y, [N, img_y.shape[1], img_x.shape[2]])
# TODO: Because paddle.expand transform error when dygraph
# to static, use reshape to avoid mistakes.
gx = paddle.reshape(gx, [N, img_y.shape[1], img_x.shape[2]])
gy = paddle.reshape(gy, [N, img_y.shape[1], img_x.shape[2]])
grid = paddle.stack([gx, gy], axis=3) grid = paddle.stack([gx, gy], axis=3)
img_masks = F.grid_sample(masks, grid, align_corners=False) img_masks = F.grid_sample(masks, grid, align_corners=False)
return img_masks[:, 0] return img_masks[:, 0]
...@@ -129,19 +134,24 @@ class MaskPostProcess(object): ...@@ -129,19 +134,24 @@ class MaskPostProcess(object):
""" """
Paste the mask prediction to the original image. Paste the mask prediction to the original image.
""" """
assert bboxes.shape[0] > 0, 'There is no detection output'
num_mask = mask_out.shape[0] num_mask = mask_out.shape[0]
# TODO: support bs > 1 origin_shape = paddle.cast(origin_shape, 'int32')
# TODO: support bs > 1 and mask output dtype is bool
pred_result = paddle.zeros( pred_result = paddle.zeros(
[num_mask, origin_shape[0][0], origin_shape[0][1]], dtype='bool') [num_mask, origin_shape[0][0], origin_shape[0][1]], dtype='int32')
if bboxes.shape[0] == 0:
return pred_result
# TODO: optimize chunk paste # TODO: optimize chunk paste
pred_result = []
for i in range(bboxes.shape[0]): for i in range(bboxes.shape[0]):
im_h, im_w = origin_shape[i] im_h, im_w = origin_shape[i][0], origin_shape[i][1]
pred_mask = self.paste_mask(mask_out[i], bboxes[i:i + 1, 2:], im_h, pred_mask = self.paste_mask(mask_out[i], bboxes[i:i + 1, 2:], im_h,
im_w) im_w)
pred_mask = pred_mask >= self.binary_thresh pred_mask = pred_mask >= self.binary_thresh
pred_result[i] = pred_mask pred_mask = paddle.cast(pred_mask, 'int32')
pred_result.append(pred_mask)
pred_result = paddle.concat(pred_result)
return pred_result return pred_result
......
...@@ -24,7 +24,7 @@ from .. import ops ...@@ -24,7 +24,7 @@ from .. import ops
@register @register
class AnchorGenerator(object): class AnchorGenerator(nn.Layer):
def __init__(self, def __init__(self,
anchor_sizes=[32, 64, 128, 256, 512], anchor_sizes=[32, 64, 128, 256, 512],
aspect_ratios=[0.5, 1.0, 2.0], aspect_ratios=[0.5, 1.0, 2.0],
...@@ -64,17 +64,21 @@ class AnchorGenerator(object): ...@@ -64,17 +64,21 @@ class AnchorGenerator(object):
self.generate_cell_anchors(s, a) self.generate_cell_anchors(s, a)
for s, a in zip(sizes, aspect_ratios) for s, a in zip(sizes, aspect_ratios)
] ]
[
self.register_buffer(
t.name, t, persistable=False) for t in cell_anchors
]
return cell_anchors return cell_anchors
def _create_grid_offsets(self, size, stride, offset): def _create_grid_offsets(self, size, stride, offset):
grid_height, grid_width = size grid_height, grid_width = size[0], size[1]
shifts_x = paddle.arange( shifts_x = paddle.arange(
offset * stride, grid_width * stride, step=stride, dtype='float32') offset * stride, grid_width * stride, step=stride, dtype='float32')
shifts_y = paddle.arange( shifts_y = paddle.arange(
offset * stride, grid_height * stride, step=stride, dtype='float32') offset * stride, grid_height * stride, step=stride, dtype='float32')
shift_y, shift_x = paddle.meshgrid(shifts_y, shifts_x) shift_y, shift_x = paddle.meshgrid(shifts_y, shifts_x)
shift_x = shift_x.reshape([-1]) shift_x = paddle.reshape(shift_x, [-1])
shift_y = shift_y.reshape([-1]) shift_y = paddle.reshape(shift_y, [-1])
return shift_x, shift_y return shift_x, shift_y
def _grid_anchors(self, grid_sizes): def _grid_anchors(self, grid_sizes):
...@@ -84,14 +88,15 @@ class AnchorGenerator(object): ...@@ -84,14 +88,15 @@ class AnchorGenerator(object):
shift_x, shift_y = self._create_grid_offsets(size, stride, shift_x, shift_y = self._create_grid_offsets(size, stride,
self.offset) self.offset)
shifts = paddle.stack((shift_x, shift_y, shift_x, shift_y), axis=1) shifts = paddle.stack((shift_x, shift_y, shift_x, shift_y), axis=1)
shifts = paddle.reshape(shifts, [-1, 1, 4])
base_anchors = paddle.reshape(base_anchors, [1, -1, 4])
anchors.append((shifts.reshape([-1, 1, 4]) + base_anchors.reshape( anchors.append(paddle.reshape(shifts + base_anchors, [-1, 4]))
[1, -1, 4])).reshape([-1, 4]))
return anchors return anchors
def __call__(self, input): def forward(self, input):
grid_sizes = [feature_map.shape[-2:] for feature_map in input] grid_sizes = [paddle.shape(feature_map)[-2:] for feature_map in input]
anchors_over_all_feature_maps = self._grid_anchors(grid_sizes) anchors_over_all_feature_maps = self._grid_anchors(grid_sizes)
return anchors_over_all_feature_maps return anchors_over_all_feature_maps
...@@ -105,4 +110,4 @@ class AnchorGenerator(object): ...@@ -105,4 +110,4 @@ class AnchorGenerator(object):
ratios and 5 sizes, the number of anchors is 15. ratios and 5 sizes, the number of anchors is 15.
For FPN models, `num_anchors` on every feature map is the same. For FPN models, `num_anchors` on every feature map is the same.
""" """
return self.cell_anchors[0].shape[0] return len(self.cell_anchors[0])
...@@ -108,7 +108,14 @@ class RPNHead(nn.Layer): ...@@ -108,7 +108,14 @@ class RPNHead(nn.Layer):
anchors = self.anchor_generator(rpn_feats) anchors = self.anchor_generator(rpn_feats)
rois, rois_num = self._gen_proposal(scores, deltas, anchors, inputs) # TODO: Fix batch_size > 1 when testing.
if self.training:
batch_size = im_shape.shape[0]
else:
batch_size = 1
rois, rois_num = self._gen_proposal(scores, deltas, anchors, inputs,
batch_size)
if self.training: if self.training:
loss = self.get_loss(scores, deltas, anchors, inputs) loss = self.get_loss(scores, deltas, anchors, inputs)
...@@ -116,16 +123,15 @@ class RPNHead(nn.Layer): ...@@ -116,16 +123,15 @@ class RPNHead(nn.Layer):
else: else:
return rois, rois_num, None return rois, rois_num, None
def _gen_proposal(self, scores, bbox_deltas, anchors, inputs): def _gen_proposal(self, scores, bbox_deltas, anchors, inputs, batch_size):
""" """
scores (list[Tensor]): Multi-level scores prediction scores (list[Tensor]): Multi-level scores prediction
bbox_deltas (list[Tensor]): Multi-level deltas prediction bbox_deltas (list[Tensor]): Multi-level deltas prediction
anchors (list[Tensor]): Multi-level anchors anchors (list[Tensor]): Multi-level anchors
inputs (dict): ground truth info inputs (dict): ground truth info
""" """
prop_gen = self.train_proposal if self.training else self.test_proposal prop_gen = self.train_proposal if self.training else self.test_proposal
im_shape = inputs['im_shape'] im_shape = inputs['im_shape']
batch_size = im_shape.shape[0]
rpn_rois_list = [[] for i in range(batch_size)] rpn_rois_list = [[] for i in range(batch_size)]
rpn_prob_list = [[] for i in range(batch_size)] rpn_prob_list = [[] for i in range(batch_size)]
rpn_rois_num_list = [[] for i in range(batch_size)] rpn_rois_num_list = [[] for i in range(batch_size)]
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six import six
import os import os
import numpy as np import numpy as np
import cv2 import cv2
def get_det_res(bboxes, def get_det_res(bboxes, bbox_nums, image_id, label_to_cat_id_map, bias=0):
scores,
labels,
bbox_nums,
image_id,
label_to_cat_id_map,
bias=0):
det_res = [] det_res = []
k = 0 k = 0
for i in range(len(bbox_nums)): for i in range(len(bbox_nums)):
cur_image_id = int(image_id[i][0]) cur_image_id = int(image_id[i][0])
det_nums = bbox_nums[i] det_nums = bbox_nums[i]
for j in range(det_nums): for j in range(det_nums):
box = bboxes[k] dt = bboxes[k]
score = float(scores[k])
label = int(labels[k])
if label < 0: continue
k = k + 1 k = k + 1
xmin, ymin, xmax, ymax = box.tolist() num_id, score, xmin, ymin, xmax, ymax = dt.tolist()
category_id = label_to_cat_id_map[label] if int(num_id) < 0:
continue
category_id = label_to_cat_id_map[int(num_id)]
w = xmax - xmin + bias w = xmax - xmin + bias
h = ymax - ymin + bias h = ymax - ymin + bias
bbox = [xmin, ymin, w, h] bbox = [xmin, ymin, w, h]
...@@ -37,8 +43,7 @@ def get_det_res(bboxes, ...@@ -37,8 +43,7 @@ def get_det_res(bboxes,
return det_res return det_res
def get_seg_res(masks, scores, labels, mask_nums, image_id, def get_seg_res(masks, bboxes, mask_nums, image_id, label_to_cat_id_map):
label_to_cat_id_map):
import pycocotools.mask as mask_util import pycocotools.mask as mask_util
seg_res = [] seg_res = []
k = 0 k = 0
...@@ -46,9 +51,9 @@ def get_seg_res(masks, scores, labels, mask_nums, image_id, ...@@ -46,9 +51,9 @@ def get_seg_res(masks, scores, labels, mask_nums, image_id,
cur_image_id = int(image_id[i][0]) cur_image_id = int(image_id[i][0])
det_nums = mask_nums[i] det_nums = mask_nums[i]
for j in range(det_nums): for j in range(det_nums):
mask = masks[k] mask = masks[k].astype(np.uint8)
score = float(scores[k]) score = float(bboxes[k][1])
label = int(labels[k]) label = int(bboxes[k][0])
k = k + 1 k = k + 1
cat_id = label_to_cat_id_map[label] cat_id = label_to_cat_id_map[label]
rle = mask_util.encode( rle = mask_util.encode(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册