提交 fbe1d120 编写于 作者: S still-wait

add solov2 model

上级 a8103906
# SOLOv2 (Segmenting Objects by Locations) for instance segmentation
## Introduction
- SOLOv2 is a fast instance segmentation framework with strong performance: [https://arxiv.org/abs/2003.10152](https://arxiv.org/abs/2003.10152)
```
@misc{wang2020solov2,
title={SOLOv2: Dynamic, Faster and Stronger},
author={Xinlong Wang and Rufeng Zhang and Tao Kong and Lei Li and Chunhua Shen},
year={2020},
eprint={2003.10152},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
## Model Zoo
| Backbone | Multi-scale training | Lr schd | Inf time (fps) | Mask AP | Download | Configs |
| :---------------------: | :-------------------: | :-----: | :------------: | :-----: | :---------: | :------------------------: |
| R50-FPN | False | 1x | - | 34.7 | [model](https://paddlemodels.bj.bcebos.com/object_detection/solov2_r50_fpn_1x.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/solov2/solov2_r50_fpn_1x.yml) |
architecture: SOLOv2
use_gpu: true
max_iters: 90000
snapshot_iter: 10000
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
metric: COCO
weights: output/solov2_r50_fpn_1x/model_final
num_classes: 81
SOLOv2:
backbone: ResNet
fpn: FPN
bbox_head: SOLOv2Head
mask_head: SOLOv2MaskHead
batch_size: 2
ResNet:
depth: 50
feature_maps: [2, 3, 4, 5]
freeze_at: 2
norm_type: bn
FPN:
max_level: 6
min_level: 2
num_chan: 256
spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
reverse_out: True
SOLOv2Head:
seg_feat_channels: 512
stacked_convs: 4
num_grids: [40, 36, 24, 16, 12]
kernel_out_channels: 256
SOLOv2MaskHead:
out_channels: 128
start_level: 0
end_level: 3
num_classes: 256
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [60000, 80000]
- !LinearWarmup
start_factor: 0.
steps: 1000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
_READER_: 'solov2_reader.yml'
TrainReader:
batch_size: 2
worker_num: 2
inputs_def:
fields: ['image', 'im_id', 'gt_segm']
dataset:
!COCODataSet
dataset_dir: dataset/coco
anno_path: annotations/instances_train2017.json
image_dir: train2017
sample_transforms:
- !DecodeImage
to_rgb: true
- !Poly2Mask {}
- !ResizeImage
target_size: 800
max_size: 1333
interp: 1
use_cv2: true
resize_box: true
- !RandomFlipImage
prob: 0.5
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !Permute
to_bgr: false
channel_first: true
batch_transforms:
- !PadBatch
pad_to_stride: 32
- !Gt2Solov2Target
num_grids: [40, 36, 24, 16, 12]
scale_ranges: [[1, 96], [48, 192], [96, 384], [192, 768], [384, 2048]]
coord_sigma: 0.2
shuffle: True
EvalReader:
inputs_def:
fields: ['image', 'im_info', 'im_id']
dataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco
sample_transforms:
- !DecodeImage
to_rgb: true
- !ResizeImage
interp: 1
max_size: 1333
target_size: 800
use_cv2: true
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !Permute
channel_first: true
to_bgr: false
batch_transforms:
- !PadBatch
pad_to_stride: 32
use_padded_im_info: false
batch_size: 1
shuffle: false
drop_last: false
drop_empty: false
worker_num: 2
TestReader:
inputs_def:
fields: ['image', 'im_info', 'im_id', 'im_shape']
dataset:
!ImageFolder
anno_path: dataset/coco/annotations/instances_val2017.json
sample_transforms:
- !DecodeImage
to_rgb: true
- !ResizeImage
interp: 1
max_size: 1333
target_size: 800
use_cv2: true
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !Permute
channel_first: true
to_bgr: false
batch_transforms:
- !PadBatch
pad_to_stride: 32
use_padded_im_info: false
......@@ -30,6 +30,7 @@ RESIZE_SCALE_SET = {
'RCNN',
'RetinaNet',
'FCOS',
'SOLOv2',
}
SUPPORT_MODELS = {
......@@ -41,6 +42,7 @@ SUPPORT_MODELS = {
'Face',
'TTF',
'FCOS',
'SOLOv2',
}
......@@ -85,7 +87,8 @@ class Resize(object):
max_size,
use_cv2=True,
image_shape=None,
interp=cv2.INTER_LINEAR):
interp=cv2.INTER_LINEAR,
resize_box=False):
self.target_size = target_size
self.max_size = max_size
self.image_shape = image_shape
......@@ -251,7 +254,7 @@ class PadStride(object):
pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
padding_im[:, :im_h, :im_w] = im
im_info['resize_shape'] = padding_im.shape[1:]
im_info['pad_shape'] = padding_im.shape[1:]
return padding_im, im_info
......@@ -268,23 +271,29 @@ def create_inputs(im, im_info, model_arch='YOLO'):
inputs['image'] = im
origin_shape = list(im_info['origin_shape'])
resize_shape = list(im_info['resize_shape'])
pad_shape = list(im_info['pad_shape']) if 'pad_shape' in im_info else list(
im_info['resize_shape'])
scale_x, scale_y = im_info['scale']
if 'YOLO' in model_arch:
im_size = np.array([origin_shape]).astype('int32')
inputs['im_size'] = im_size
elif 'RetinaNet' or 'EfficientDet' in model_arch:
elif 'RetinaNet' in model_arch or 'EfficientDet' in model_arch:
scale = scale_x
im_info = np.array([resize_shape + [scale]]).astype('float32')
im_info = np.array([pad_shape + [scale]]).astype('float32')
inputs['im_info'] = im_info
elif ('RCNN' in model_arch) or ('FCOS' in model_arch):
scale = scale_x
im_info = np.array([resize_shape + [scale]]).astype('float32')
im_info = np.array([pad_shape + [scale]]).astype('float32')
im_shape = np.array([origin_shape + [1.]]).astype('float32')
inputs['im_info'] = im_info
inputs['im_shape'] = im_shape
elif 'TTF' in model_arch:
scale_factor = np.array([scale_x, scale_y] * 2).astype('float32')
inputs['scale_factor'] = scale_factor
elif 'SOLOv2' in model_arch:
scale = scale_x
im_info = np.array([resize_shape + [scale]]).astype('float32')
inputs['im_info'] = im_info
return inputs
......@@ -405,10 +414,15 @@ def visualize(image_file,
results,
labels,
mask_resolution=14,
output_dir='output/'):
output_dir='output/',
threshold=0.5):
# visualize the predict result
im = visualize_box_mask(
image_file, results, labels, mask_resolution=mask_resolution)
image_file,
results,
labels,
mask_resolution=mask_resolution,
threshold=threshold)
img_name = os.path.split(image_file)[-1]
if not os.path.exists(output_dir):
os.makedirs(output_dir)
......@@ -516,6 +530,11 @@ class Detector():
ms = (t2 - t1) * 1000.0 / repeats
print("Inference: {} ms per batch image".format(ms))
if self.config.arch == 'SOLOv2':
return dict(
segm=np.array(outs[2]),
label=np.array(outs[0]),
score=np.array(outs[1]))
np_boxes = np.array(outs[0])
if self.config.mask_resolution is not None:
np_masks = np.array(outs[1])
......@@ -539,6 +558,13 @@ class Detector():
for i in range(repeats):
self.predictor.zero_copy_run()
output_names = self.predictor.get_output_names()
if self.config.arch == 'SOLOv2':
np_label = self.predictor.get_output_tensor(output_names[
0]).copy_to_cpu()
np_score = self.predictor.get_output_tensor(output_names[
1]).copy_to_cpu()
np_segms = self.predictor.get_output_tensor(output_names[
2]).copy_to_cpu()
boxes_tensor = self.predictor.get_output_tensor(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu()
if self.config.mask_resolution is not None:
......@@ -552,6 +578,9 @@ class Detector():
# do not perform postprocess in benchmark mode
results = []
if not run_benchmark:
if self.config.arch == 'SOLOv2':
return dict(segm=np_segms, label=np_label, score=np_score)
if reduce(lambda x, y: x * y, np_boxes.shape) < 6:
print('[WARNNING] No object detected.')
results = {'boxes': np.array([])}
......@@ -579,7 +608,8 @@ def predict_image():
results,
detector.config.labels,
mask_resolution=detector.config.mask_resolution,
output_dir=FLAGS.output_dir)
output_dir=FLAGS.output_dir,
threshold=FLAGS.threshold)
def predict_video(camera_id):
......
......@@ -18,20 +18,22 @@ from __future__ import division
import cv2
import numpy as np
from PIL import Image, ImageDraw
from scipy import ndimage
def visualize_box_mask(im, results, labels, mask_resolution=14):
"""
def visualize_box_mask(im, results, labels, mask_resolution=14, threshold=0.5):
"""
Args:
im (str/np.ndarray): path of image/np.ndarray read by cv2
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max]
MaskRCNN's results include 'masks': np.ndarray:
shape:[N, class_num, mask_resolution, mask_resolution]
MaskRCNN's results include 'masks': np.ndarray:
shape:[N, class_num, mask_resolution, mask_resolution]
labels (list): labels:['class1', ..., 'classn']
mask_resolution (int): shape of a mask is:[mask_resolution, mask_resolution]
threshold (float): Threshold of score.
Returns:
im (PIL.Image.Image): visualized image
im (PIL.Image.Image): visualized image
"""
if isinstance(im, str):
im = Image.open(im).convert('RGB')
......@@ -46,15 +48,23 @@ def visualize_box_mask(im, results, labels, mask_resolution=14):
resolution=mask_resolution)
if 'boxes' in results:
im = draw_box(im, results['boxes'], labels)
if 'segm' in results:
im = draw_segm(
im,
results['segm'],
results['label'],
results['score'],
labels,
threshold=threshold)
return im
def get_color_map_list(num_classes):
"""
"""
Args:
num_classes (int): number of class
Returns:
color_map (list): RGB color list
color_map (list): RGB color list
"""
color_map = num_classes * [0, 0, 0]
for i in range(0, num_classes):
......@@ -71,9 +81,9 @@ def get_color_map_list(num_classes):
def expand_boxes(boxes, scale=0.0):
"""
"""
Args:
boxes (np.ndarray): shape:[N,4], N:number of box
boxes (np.ndarray): shape:[N,4], N:number of box,
matix element:[x_min, y_min, x_max, y_max]
scale (float): scale of boxes
Returns:
......@@ -94,17 +104,17 @@ def expand_boxes(boxes, scale=0.0):
def draw_mask(im, np_boxes, np_masks, labels, resolution=14, threshold=0.5):
"""
"""
Args:
im (PIL.Image.Image): PIL image
np_boxes (np.ndarray): shape:[N,6], N: number of box
np_boxes (np.ndarray): shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max]
np_masks (np.ndarray): shape:[N, class_num, resolution, resolution]
labels (list): labels:['class1', ..., 'classn']
resolution (int): shape of a mask is:[resolution, resolution]
threshold (float): threshold of mask
Returns:
im (PIL.Image.Image): visualized image
im (PIL.Image.Image): visualized image
"""
color_list = get_color_map_list(len(labels))
scale = (resolution + 2.0) / resolution
......@@ -149,14 +159,14 @@ def draw_mask(im, np_boxes, np_masks, labels, resolution=14, threshold=0.5):
def draw_box(im, np_boxes, labels):
"""
"""
Args:
im (PIL.Image.Image): PIL image
np_boxes (np.ndarray): shape:[N,6], N: number of box
np_boxes (np.ndarray): shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max]
labels (list): labels:['class1', ..., 'classn']
Returns:
im (PIL.Image.Image): visualized image
im (PIL.Image.Image): visualized image
"""
draw_thickness = min(im.size) // 320
draw = ImageDraw.Draw(im)
......@@ -186,3 +196,43 @@ def draw_box(im, np_boxes, labels):
[(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
return im
def draw_segm(im,
np_segms,
np_label,
np_score,
labels,
threshold=0.5,
alpha=0.7):
"""
Draw segmentation on image
"""
mask_color_id = 0
w_ratio = .4
color_list = get_color_map_list(len(labels))
im = np.array(im).astype('float32')
clsid2color = {}
np_segms = np_segms.astype(np.uint8)
for i in range(np_segms.shape[0]):
mask, score, clsid = np_segms[i], np_score[i], np_label[i] + 1
if score < threshold:
continue
if clsid not in clsid2color:
clsid2color[clsid] = color_list[clsid]
color_mask = clsid2color[clsid]
for c in range(3):
color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
idx = np.nonzero(mask)
color_mask = np.array(color_mask)
im[idx[0], idx[1], :] *= 1.0 - alpha
im[idx[0], idx[1], :] += alpha * color_mask
center_y, center_x = ndimage.measurements.center_of_mass(mask)
label_text = "{}".format(labels[clsid])
print(label_text)
print(center_y, center_x)
vis_pos = (max(int(center_x) - 10, 0), int(center_y))
cv2.putText(im, label_text, vis_pos, cv2.FONT_HERSHEY_COMPLEX, 0.3,
(255, 255, 255))
return Image.fromarray(im.astype('uint8'))
......@@ -24,6 +24,7 @@ except Exception:
import logging
import cv2
import numpy as np
from scipy import ndimage
from .operators import register_op, BaseOperator
from .op_helper import jaccard_overlap, gaussian2D
......@@ -37,6 +38,7 @@ __all__ = [
'Gt2YoloTarget',
'Gt2FCOSTarget',
'Gt2TTFTarget',
'Gt2Solov2Target',
]
......@@ -88,6 +90,13 @@ class PadBatch(BaseOperator):
(1, max_shape[1], max_shape[2]), dtype=np.float32)
padding_sem[:, :im_h, :im_w] = semantic
data['semantic'] = padding_sem
if 'gt_segm' in data.keys() and data['gt_segm'] is not None:
gt_segm = data['gt_segm']
padding_segm = np.zeros(
(gt_segm.shape[0], max_shape[1], max_shape[2]),
dtype=np.uint8)
padding_segm[:, :im_h, :im_w] = gt_segm
data['gt_segm'] = padding_segm
return samples
......@@ -590,3 +599,154 @@ class Gt2TTFTarget(BaseOperator):
heatmap[y - top:y + bottom, x - left:x + right] = np.maximum(
masked_heatmap, masked_gaussian)
return heatmap
@register_op
class Gt2Solov2Target(BaseOperator):
"""Assign mask target and labels in SOLOv2 network.
Args:
num_grids (list): The list of feature map grids size.
scale_ranges (list): The list of mask boundary range.
coord_sigma (float): The coefficient of coordinate area length.
sampling_ratio (float): The ratio of down sampling.
"""
def __init__(self,
num_grids=[40, 36, 24, 16, 12],
scale_ranges=[[1, 96], [48, 192], [96, 384], [192, 768],
[384, 2048]],
coord_sigma=0.2,
sampling_ratio=4.0):
super(Gt2Solov2Target, self).__init__()
self.num_grids = num_grids
self.scale_ranges = scale_ranges
self.coord_sigma = coord_sigma
self.sampling_ratio = sampling_ratio
def _scale_size(self, im, scale):
h, w = im.shape[:2]
new_size = (int(w * float(scale) + 0.5), int(h * float(scale) + 0.5))
resized_img = cv2.resize(
im, None, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
return resized_img
def __call__(self, samples, context=None):
for sample in samples:
gt_bboxes_raw = sample['gt_bbox']
gt_labels_raw = sample['gt_class']
im_c, im_h, im_w = sample['image'].shape[:]
gt_masks_raw = sample['gt_segm'].astype(np.uint8)
mask_feat_size = [
int(im_h / self.sampling_ratio), int(im_w / self.sampling_ratio)
]
gt_areas = np.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) *
(gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))
ins_ind_label_list = []
grid_offset = []
idx = 0
for (lower_bound, upper_bound), num_grid \
in zip(self.scale_ranges, self.num_grids):
hit_indices = ((gt_areas >= lower_bound) &
(gt_areas <= upper_bound)).nonzero()[0]
num_ins = len(hit_indices)
ins_label = []
grid_order = []
cate_label = np.zeros([num_grid, num_grid], dtype=np.int64)
ins_ind_label = np.zeros([num_grid**2], dtype=np.bool)
if num_ins == 0:
ins_label = np.zeros(
[1, mask_feat_size[0], mask_feat_size[1]],
dtype=np.uint8)
ins_ind_label_list.append(ins_ind_label)
sample['cate_label{}'.format(idx)] = cate_label.flatten()
sample['ins_label{}'.format(idx)] = ins_label
sample['grid_order{}'.format(idx)] = np.asarray([0])
grid_offset.append(1)
idx += 1
continue
gt_bboxes = gt_bboxes_raw[hit_indices]
gt_labels = gt_labels_raw[hit_indices]
gt_masks = gt_masks_raw[hit_indices, ...]
half_ws = 0.5 * (
gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.coord_sigma
half_hs = 0.5 * (
gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.coord_sigma
for seg_mask, gt_label, half_h, half_w in zip(
gt_masks, gt_labels, half_hs, half_ws):
if seg_mask.sum() == 0:
continue
# mass center
upsampled_size = (mask_feat_size[0] * 4,
mask_feat_size[1] * 4)
center_h, center_w = ndimage.measurements.center_of_mass(
seg_mask)
coord_w = int(
(center_w / upsampled_size[1]) // (1. / num_grid))
coord_h = int(
(center_h / upsampled_size[0]) // (1. / num_grid))
# left, top, right, down
top_box = max(0,
int(((center_h - half_h) / upsampled_size[0])
// (1. / num_grid)))
down_box = min(num_grid - 1,
int(((center_h + half_h) / upsampled_size[0])
// (1. / num_grid)))
left_box = max(0,
int(((center_w - half_w) / upsampled_size[1])
// (1. / num_grid)))
right_box = min(num_grid - 1,
int(((center_w + half_w) /
upsampled_size[1]) // (1. / num_grid)))
top = max(top_box, coord_h - 1)
down = min(down_box, coord_h + 1)
left = max(coord_w - 1, left_box)
right = min(right_box, coord_w + 1)
cate_label[top:(down + 1), left:(right + 1)] = gt_label
seg_mask = self._scale_size(
seg_mask, scale=1. / self.sampling_ratio)
for i in range(top, down + 1):
for j in range(left, right + 1):
label = int(i * num_grid + j)
cur_ins_label = np.zeros(
[mask_feat_size[0], mask_feat_size[1]],
dtype=np.uint8)
cur_ins_label[:seg_mask.shape[0], :seg_mask.shape[
1]] = seg_mask
ins_label.append(cur_ins_label)
ins_ind_label[label] = True
grid_order.append(label)
if ins_label == []:
ins_label = np.zeros(
[1, mask_feat_size[0], mask_feat_size[1]],
dtype=np.uint8)
ins_ind_label_list.append(ins_ind_label)
sample['cate_label{}'.format(idx)] = cate_label.flatten()
sample['ins_label{}'.format(idx)] = ins_label
sample['grid_order{}'.format(idx)] = np.asarray([0])
grid_offset.append(1)
else:
ins_label = np.stack(ins_label, axis=0)
ins_ind_label_list.append(ins_ind_label)
sample['cate_label{}'.format(idx)] = cate_label.flatten()
sample['ins_label{}'.format(idx)] = ins_label
sample['grid_order{}'.format(idx)] = np.asarray(grid_order)
assert len(grid_order) > 0
grid_offset.append(len(grid_order))
idx += 1
ins_ind_labels = np.concatenate([
ins_ind_labels_level_img
for ins_ind_labels_level_img in ins_ind_label_list
])
fg_num = np.sum(ins_ind_labels)
sample['fg_num'] = fg_num
sample['grid_offset'] = np.asarray(grid_offset).astype(np.int32)
return samples
......@@ -272,7 +272,8 @@ class ResizeImage(BaseOperator):
target_size=0,
max_size=0,
interp=cv2.INTER_LINEAR,
use_cv2=True):
use_cv2=True,
resize_box=False):
"""
Rescale image to the specified target size, and capped at max_size
if max_size != 0.
......@@ -285,11 +286,13 @@ class ResizeImage(BaseOperator):
interp (int): the interpolation method
use_cv2 (bool): use the cv2 interpolation method or use PIL
interpolation method
resize_box (bool): whether resize ground truth bbox annotations.
"""
super(ResizeImage, self).__init__()
self.max_size = int(max_size)
self.interp = int(interp)
self.use_cv2 = use_cv2
self.resize_box = resize_box
if not (isinstance(target_size, int) or isinstance(target_size, list)):
raise TypeError(
"Type of target_size is invalid. Must be Integer or List, now is {}".
......@@ -348,18 +351,6 @@ class ResizeImage(BaseOperator):
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
if 'semantic' in sample.keys() and sample['semantic'] is not None:
semantic = sample['semantic']
semantic = cv2.resize(
semantic.astype('float32'),
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
semantic = np.asarray(semantic).astype('int32')
semantic = np.expand_dims(semantic, 0)
sample['semantic'] = semantic
else:
if self.max_size != 0:
raise TypeError(
......@@ -370,6 +361,38 @@ class ResizeImage(BaseOperator):
im = im.resize((int(resize_w), int(resize_h)), self.interp)
im = np.array(im)
sample['image'] = im
sample['scale_factor'] = [im_scale_x, im_scale_y] * 2
if 'gt_bbox' in sample and self.resize_box and len(sample[
'gt_bbox']) > 0:
bboxes = sample['gt_bbox'] * sample['scale_factor']
bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, resize_w - 1)
bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, resize_h - 1)
sample['gt_bbox'] = bboxes
if 'semantic' in sample.keys() and sample['semantic'] is not None:
semantic = sample['semantic']
semantic = cv2.resize(
semantic.astype('float32'),
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
semantic = np.asarray(semantic).astype('int32')
semantic = np.expand_dims(semantic, 0)
sample['semantic'] = semantic
if 'gt_segm' in sample and len(sample['gt_segm']) > 0:
masks = [
cv2.resize(
gt_segm,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=cv2.INTER_NEAREST)
for gt_segm in sample['gt_segm']
]
sample['gt_segm'] = np.asarray(masks).astype(np.uint8)
return sample
......@@ -473,7 +496,6 @@ class RandomFlipImage(BaseOperator):
if self.is_mask_flip and len(sample['gt_poly']) != 0:
sample['gt_poly'] = self.flip_segms(sample['gt_poly'],
height, width)
if 'gt_keypoint' in sample.keys():
sample['gt_keypoint'] = self.flip_keypoint(
sample['gt_keypoint'], width)
......@@ -482,6 +504,9 @@ class RandomFlipImage(BaseOperator):
'semantic'] is not None:
sample['semantic'] = sample['semantic'][:, ::-1]
if 'gt_segm' in sample.keys() and sample['gt_segm'] is not None:
sample['gt_segm'] = sample['gt_segm'][:, :, ::-1]
sample['flipped'] = True
sample['image'] = im
sample = samples if batch_input else samples[0]
......@@ -2557,3 +2582,41 @@ class DebugVisibleImage(BaseOperator):
save_path = os.path.join(self.output_dir, out_file_name)
image.save(save_path, quality=95)
return sample
@register_op
class Poly2Mask(BaseOperator):
"""
gt poly to mask annotations
"""
def __init__(self):
super(Poly2Mask, self).__init__()
import pycocotools.mask as maskUtils
self.maskutils = maskUtils
def _poly2mask(self, mask_ann, img_h, img_w):
if isinstance(mask_ann, list):
# polygon -- a single object might consist of multiple parts
# we merge all parts into one mask rle code
rles = self.maskutils.frPyObjects(mask_ann, img_h, img_w)
rle = self.maskutils.merge(rles)
elif isinstance(mask_ann['counts'], list):
# uncompressed RLE
rle = self.maskutils.frPyObjects(mask_ann, img_h, img_w)
else:
# rle
rle = mask_ann
mask = self.maskutils.decode(rle)
return mask
def __call__(self, sample, context=None):
assert 'gt_poly' in sample
im_h = sample['h']
im_w = sample['w']
masks = [
self._poly2mask(gt_poly, im_h, im_w)
for gt_poly in sample['gt_poly']
]
sample['gt_segm'] = np.asarray(masks).astype(np.uint8)
return sample
......@@ -22,6 +22,7 @@ from . import roi_extractors
from . import roi_heads
from . import ops
from . import target_assigners
from . import mask_head
from .anchor_heads import *
from .architectures import *
......@@ -30,3 +31,4 @@ from .roi_extractors import *
from .roi_heads import *
from .ops import *
from .target_assigners import *
from .mask_head import *
......@@ -21,6 +21,7 @@ from . import fcos_head
from . import corner_head
from . import efficient_head
from . import ttf_head
from . import solov2_head
from .rpn_head import *
from .yolo_head import *
......@@ -29,3 +30,4 @@ from .fcos_head import *
from .corner_head import *
from .efficient_head import *
from .ttf_head import *
from .solov2_head import *
此差异已折叠。
......@@ -29,6 +29,7 @@ from . import fcos
from . import cornernet_squeeze
from . import ttfnet
from . import htc
from . import solov2
from .faster_rcnn import *
from .mask_rcnn import *
......@@ -45,3 +46,4 @@ from .fcos import *
from .cornernet_squeeze import *
from .ttfnet import *
from .htc import *
from .solov2 import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
from paddle import fluid
from ppdet.experimental import mixed_precision_global_state
from ppdet.core.workspace import register
__all__ = ['SOLOv2']
@register
class SOLOv2(object):
"""
SOLOv2 network, see https://arxiv.org/abs/2003.10152
Args:
backbone (object): an backbone instance
fpn (object): feature pyramid network instance
bbox_head (object): an `SOLOv2Head` instance
mask_head (object): an `SOLOv2MaskHead` instance
batch_size (int): batch size.
"""
__category__ = 'architecture'
__inject__ = ['backbone', 'fpn', 'bbox_head', 'mask_head']
def __init__(self,
backbone,
fpn=None,
bbox_head='SOLOv2Head',
mask_head='SOLOv2MaskHead',
batch_size=1):
super(SOLOv2, self).__init__()
self.backbone = backbone
self.fpn = fpn
self.bbox_head = bbox_head
self.mask_head = mask_head
self.batch_size = batch_size
def build(self, feed_vars, mode='train'):
im = feed_vars['image']
mixed_precision_enabled = mixed_precision_global_state() is not None
# cast inputs to FP16
if mixed_precision_enabled:
im = fluid.layers.cast(im, 'float16')
body_feats = self.backbone(im)
if self.fpn is not None:
body_feats, spatial_scale = self.fpn.get_output(body_feats)
if isinstance(body_feats, OrderedDict):
body_feat_names = list(body_feats.keys())
body_feats = [body_feats[name] for name in body_feat_names]
# cast features back to FP32
if mixed_precision_enabled:
body_feats = [fluid.layers.cast(v, 'float32') for v in body_feats]
if not mode == 'train':
self.batch_size = 1
mask_feat_pred = self.mask_head.get_output(body_feats, self.batch_size)
if mode == 'train':
ins_labels = []
cate_labels = []
grid_orders = []
fg_num = feed_vars['fg_num']
grid_offset = feed_vars['grid_offset']
for i in range(5):
ins_label = 'ins_label{}'.format(i)
if ins_label in feed_vars:
ins_labels.append(feed_vars[ins_label])
cate_label = 'cate_label{}'.format(i)
if cate_label in feed_vars:
cate_labels.append(feed_vars[cate_label])
grid_order = 'grid_order{}'.format(i)
if grid_order in feed_vars:
grid_orders.append(feed_vars[grid_order])
cate_preds, kernel_preds = self.bbox_head.get_outputs(
body_feats, batch_size=self.batch_size)
losses = self.bbox_head.get_loss(
cate_preds, kernel_preds, mask_feat_pred, ins_labels,
cate_labels, grid_orders, fg_num, grid_offset, self.batch_size)
total_loss = fluid.layers.sum(list(losses.values()))
losses.update({'loss': total_loss})
return losses
else:
im_info = feed_vars['im_info']
outs = self.bbox_head.get_outputs(
body_feats, is_eval=True, batch_size=self.batch_size)
seg_inputs = outs + (mask_feat_pred, im_info)
return self.bbox_head.get_prediction(*seg_inputs)
def _inputs_def(self, image_shape, fields):
im_shape = [None] + image_shape
# yapf: disable
inputs_def = {
'image': {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0},
'im_info': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0},
'im_shape': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
}
if 'gt_segm' in fields:
targets_def = {
'ins_label0': {'shape': [None, None, None], 'dtype': 'int32', 'lod_level': 1},
'ins_label1': {'shape': [None, None, None], 'dtype': 'int32', 'lod_level': 1},
'ins_label2': {'shape': [None, None, None], 'dtype': 'int32', 'lod_level': 1},
'ins_label3': {'shape': [None, None, None], 'dtype': 'int32', 'lod_level': 1},
'ins_label4': {'shape': [None, None, None], 'dtype': 'int32', 'lod_level': 1},
'cate_label0': {'shape': [None], 'dtype': 'int32', 'lod_level': 1},
'cate_label1': {'shape': [None], 'dtype': 'int32', 'lod_level': 1},
'cate_label2': {'shape': [None], 'dtype': 'int32', 'lod_level': 1},
'cate_label3': {'shape': [None], 'dtype': 'int32', 'lod_level': 1},
'cate_label4': {'shape': [None], 'dtype': 'int32', 'lod_level': 1},
'grid_order0': {'shape': [None], 'dtype': 'int32', 'lod_level': 1},
'grid_order1': {'shape': [None], 'dtype': 'int32', 'lod_level': 1},
'grid_order2': {'shape': [None], 'dtype': 'int32', 'lod_level': 1},
'grid_order3': {'shape': [None], 'dtype': 'int32', 'lod_level': 1},
'grid_order4': {'shape': [None], 'dtype': 'int32', 'lod_level': 1},
'fg_num': {'shape': [None], 'dtype': 'int32', 'lod_level': 0},
'grid_offset': {'shape': [None, 5], 'dtype': 'int32', 'lod_level': 0},
}
# yapf: enable
inputs_def.update(targets_def)
return inputs_def
def build_inputs(
self,
image_shape=[3, None, None],
fields=['image', 'im_id', 'gt_segm'], # for train
use_dataloader=True,
iterable=False):
inputs_def = self._inputs_def(image_shape, fields)
if 'gt_segm' in fields:
fields.remove('gt_segm')
fields.extend(['fg_num', 'grid_offset'])
for i in range(5):
fields.extend([
'ins_label%d' % i, 'cate_label%d' % i, 'grid_order%d' % i
])
feed_vars = OrderedDict([(key, fluid.data(
name=key,
shape=inputs_def[key]['shape'],
dtype=inputs_def[key]['dtype'],
lod_level=inputs_def[key]['lod_level'])) for key in fields])
loader = fluid.io.DataLoader.from_generator(
feed_list=list(feed_vars.values()),
capacity=16,
use_double_buffer=True,
iterable=iterable) if use_dataloader else None
return feed_vars, loader
def train(self, feed_vars):
return self.build(feed_vars, mode='train')
def eval(self, feed_vars):
return self.build(feed_vars, mode='test')
def test(self, feed_vars):
return self.build(feed_vars, mode='test')
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from . import solo_mask_head
from .solo_mask_head import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import fluid
from ppdet.core.workspace import register
from ppdet.modeling.ops import ConvNorm, DeformConvNorm
__all__ = ['SOLOv2MaskHead']
@register
class SOLOv2MaskHead(object):
"""
SOLOv2MaskHead
Args:
out_channels (int): The channel number of output variable.
start_level (int): The position where the input starts.
end_level (int): The position where the input ends.
num_classes (int): Number of classes in SOLOv2MaskHead output.
use_dcn_in_tower: Whether to use dcn in tower or not.
"""
__shared__ = ['num_classes']
def __init__(self,
out_channels=128,
start_level=0,
end_level=3,
num_classes=81,
use_dcn_in_tower=False):
super(SOLOv2MaskHead, self).__init__()
assert start_level >= 0 and end_level >= start_level
self.out_channels = out_channels
self.start_level = start_level
self.end_level = end_level
self.num_classes = num_classes
self.use_dcn_in_tower = use_dcn_in_tower
self.conv_type = [ConvNorm, DeformConvNorm]
def _convs_levels(self, conv_feat, level, name=None):
conv_func = self.conv_type[0]
if self.use_dcn_in_tower:
conv_func = self.conv_type[1]
if level == 0:
return conv_func(
input=conv_feat,
num_filters=self.out_channels,
filter_size=3,
stride=1,
norm_type='gn',
norm_groups=32,
freeze_norm=False,
act='relu',
initializer=fluid.initializer.NormalInitializer(scale=0.01),
norm_name=name + '.conv' + str(level) + '.gn',
name=name + '.conv' + str(level))
for j in range(level):
conv_feat = conv_func(
input=conv_feat,
num_filters=self.out_channels,
filter_size=3,
stride=1,
norm_type='gn',
norm_groups=32,
freeze_norm=False,
act='relu',
initializer=fluid.initializer.NormalInitializer(scale=0.01),
norm_name=name + '.conv' + str(j) + '.gn',
name=name + '.conv' + str(j))
conv_feat = fluid.layers.resize_bilinear(
conv_feat,
scale=2,
name='upsample' + str(level) + str(j),
align_corners=False,
align_mode=0)
return conv_feat
def _conv_pred(self, conv_feat):
conv_func = self.conv_type[0]
if self.use_dcn_in_tower:
conv_func = self.conv_type[1]
conv_feat = conv_func(
input=conv_feat,
num_filters=self.num_classes,
filter_size=1,
stride=1,
norm_type='gn',
norm_groups=32,
freeze_norm=False,
act='relu',
initializer=fluid.initializer.NormalInitializer(scale=0.01),
norm_name='mask_feat_head.conv_pred.0.gn',
name='mask_feat_head.conv_pred.0')
return conv_feat
def get_output(self, inputs, batch_size=1):
"""
Get SOLOv2MaskHead output.
Args:
inputs(list[Variable]): feature map from each necks with shape of [N, C, H, W]
batch_size (int): batch size
Returns:
ins_pred(Variable): Output of SOLOv2MaskHead head
"""
range_level = self.end_level - self.start_level + 1
feature_add_all_level = self._convs_levels(
inputs[0], 0, name='mask_feat_head.convs_all_levels.0')
for i in range(1, range_level):
input_p = inputs[i]
if i == 3:
input_feat = input_p
x_range = paddle.linspace(
-1, 1, fluid.layers.shape(input_feat)[-1], dtype='float32')
y_range = paddle.linspace(
-1, 1, fluid.layers.shape(input_feat)[-2], dtype='float32')
y, x = paddle.tensor.meshgrid([y_range, x_range])
x = fluid.layers.unsqueeze(x, [0, 1])
y = fluid.layers.unsqueeze(y, [0, 1])
y = fluid.layers.expand(y, expand_times=[batch_size, 1, 1, 1])
x = fluid.layers.expand(x, expand_times=[batch_size, 1, 1, 1])
coord_feat = fluid.layers.concat([x, y], axis=1)
input_p = fluid.layers.concat([input_p, coord_feat], axis=1)
feature_add_all_level = fluid.layers.elementwise_add(
feature_add_all_level,
self._convs_levels(
input_p,
i,
name='mask_feat_head.convs_all_levels.{}'.format(i)))
ins_pred = self._conv_pred(feature_add_all_level)
return ins_pred
......@@ -17,6 +17,7 @@ from numbers import Integral
import math
import six
import paddle
from paddle import fluid
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.initializer import NumpyArrayInitializer
......@@ -1263,27 +1264,27 @@ class LibraBBoxAssigner(object):
rois = create_tmp_var(
fluid.default_main_program(),
name=None, #'rois',
name=None,
dtype='float32',
shape=[-1, 4], )
bbox_inside_weights = create_tmp_var(
fluid.default_main_program(),
name=None, #'bbox_inside_weights',
name=None,
dtype='float32',
shape=[-1, 8 if self.is_cls_agnostic else self.class_nums * 4], )
bbox_outside_weights = create_tmp_var(
fluid.default_main_program(),
name=None, #'bbox_outside_weights',
name=None,
dtype='float32',
shape=[-1, 8 if self.is_cls_agnostic else self.class_nums * 4], )
bbox_targets = create_tmp_var(
fluid.default_main_program(),
name=None, #'bbox_targets',
name=None,
dtype='float32',
shape=[-1, 8 if self.is_cls_agnostic else self.class_nums * 4], )
labels_int32 = create_tmp_var(
fluid.default_main_program(),
name=None, #'labels_int32',
name=None,
dtype='int32',
shape=[-1, 1], )
......@@ -1565,3 +1566,79 @@ class RetinaOutputDecoder(object):
self.nms_top_k = pre_nms_top_n
self.keep_top_k = detections_per_im
self.nms_eta = nms_eta
@register
@serializable
class MaskMatrixNMS(object):
"""
Matrix NMS for multi-class masks.
Args:
kernel (str): 'linear' or 'gaussian'
sigma (float): std in gaussian method
Input:
seg_masks (Variable): shape (n, h, w), segmentation feature maps
cate_labels (Variable): shape (n), mask labels in descending order
cate_scores (Variable): shape (n), mask scores in descending order
sum_masks (Variable): The sum of seg_masks
Returns:
Variable: cate_scores_update, tensors of shape (n)
"""
def __init__(self, kernel='gaussian', sigma=2.0):
super(MaskMatrixNMS, self).__init__()
self.kernel = kernel
self.sigma = sigma
def __call__(self, seg_masks, cate_labels, cate_scores, sum_masks=None):
n_samples = fluid.layers.shape(cate_labels)
seg_masks = fluid.layers.reshape(seg_masks, shape=(n_samples, -1))
# inter.
inter_matrix = paddle.mm(seg_masks,
fluid.layers.transpose(seg_masks, [1, 0]))
# union.
sum_masks_x = fluid.layers.reshape(
fluid.layers.expand(
sum_masks, expand_times=[n_samples]),
shape=[n_samples, n_samples])
# iou.
iou_matrix = (inter_matrix / (sum_masks_x + fluid.layers.transpose(
sum_masks_x, [1, 0]) - inter_matrix))
iou_matrix = paddle.tensor.triu(iou_matrix, diagonal=1)
# label_specific matrix.
cate_labels_x = fluid.layers.reshape(
fluid.layers.expand(
cate_labels, expand_times=[n_samples]),
shape=[n_samples, n_samples])
label_matrix = fluid.layers.cast(
(cate_labels_x == fluid.layers.transpose(cate_labels_x, [1, 0])),
'float32')
label_matrix = paddle.tensor.triu(label_matrix, diagonal=1)
# IoU compensation
compensate_iou = paddle.max((iou_matrix * label_matrix), axis=0)
compensate_iou = fluid.layers.reshape(
fluid.layers.expand(
compensate_iou, expand_times=[n_samples]),
shape=[n_samples, n_samples])
compensate_iou = fluid.layers.transpose(compensate_iou, [1, 0])
# IoU decay
decay_iou = iou_matrix * label_matrix
# matrix nms
if self.kernel == 'gaussian':
decay_matrix = fluid.layers.exp(-1 * self.sigma * (decay_iou**2))
compensate_matrix = fluid.layers.exp(-1 * self.sigma *
(compensate_iou**2))
decay_coefficient = paddle.min(decay_matrix / compensate_matrix,
axis=0)
elif self.kernel == 'linear':
decay_matrix = (1 - decay_iou) / (1 - compensate_iou)
decay_coefficient = paddle.min(decay_matrix, axis=0)
else:
raise NotImplementedError
# update the score.
cate_scores_update = cate_scores * decay_coefficient
return cate_scores_update
......@@ -164,6 +164,47 @@ def mask_eval(results,
cocoapi_eval(outfile, 'segm', coco_gt=coco_gt)
def segm_eval(results, anno_file, outfile, save_only=False):
assert 'segm' in results[0]
assert outfile.endswith('.json')
from pycocotools.coco import COCO
coco_gt = COCO(anno_file)
clsid2catid = {i: v for i, v in enumerate(coco_gt.getCatIds())}
segm_results = []
for t in results:
im_id = int(t['im_id'][0][0])
segs = t['segm']
for mask in segs:
catid = int(clsid2catid[mask[0]])
masks = mask[1]
mask_score = masks[1]
segm = masks[0]
segm['counts'] = segm['counts'].decode('utf8')
coco_res = {
'image_id': im_id,
'category_id': catid,
'segmentation': segm,
'score': mask_score
}
segm_results.append(coco_res)
if len(segm_results) == 0:
logger.warning("The number of valid mask detected is zero.\n \
Please use reasonable model and check input data.")
return
with open(outfile, 'w') as f:
json.dump(segm_results, f)
if save_only:
logger.info('The mask result is saved to {} and do not '
'evaluate the mAP.'.format(outfile))
return
map_stats = cocoapi_eval(outfile, 'segm', coco_gt=coco_gt)
return map_stats
def cocoapi_eval(jsonfile,
style,
coco_gt=None,
......@@ -374,6 +415,43 @@ def mask2out(results, clsid2catid, resolution, thresh_binarize=0.5):
return segm_res
def segm2out(results, clsid2catid, thresh_binarize=0.5):
import pycocotools.mask as mask_util
segm_res = []
# for each batch
for t in results:
segms = t['segm'][0]
clsid_labels = t['cate_label'][0]
clsid_scores = t['cate_score'][0]
lengths = segms.shape[0]
im_id = int(t['im_id'][0][0])
im_shape = t['im_shape'][0][0]
if lengths == 0 or segms is None:
continue
# for each sample
for i in range(lengths - 1):
im_h = int(im_shape[0])
im_w = int(im_shape[1])
clsid = int(clsid_labels[i])
catid = clsid2catid[clsid]
score = clsid_scores[i]
mask = segms[i]
segm = mask_util.encode(
np.array(
mask[:, :, np.newaxis], order='F'))[0]
segm['counts'] = segm['counts'].decode('utf8')
coco_res = {
'image_id': im_id,
'category_id': catid,
'segmentation': segm,
'score': score
}
segm_res.append(coco_res)
return segm_res
def expand_boxes(boxes, scale):
"""
Expand an array of boxes by a given scale.
......
......@@ -94,6 +94,25 @@ def clean_res(result, keep_name_list):
return clean_result
def get_masks(result):
import pycocotools.mask as mask_util
if result is None:
return {}
seg_pred = result['segm'][0].astype(np.uint8)
cate_label = result['cate_label'][0].astype(np.int)
cate_score = result['cate_score'][0].astype(np.float)
num_ins = seg_pred.shape[0]
masks = []
for idx in range(num_ins - 1):
cur_mask = seg_pred[idx, ...]
rle = mask_util.encode(
np.array(
cur_mask[:, :, np.newaxis], order='F'))[0]
rst = (rle, cate_score[idx])
masks.append([cate_label[idx], rst])
return masks
def eval_run(exe,
compile_program,
loader,
......@@ -163,11 +182,13 @@ def eval_run(exe,
corner_post_process(res, post_config, cfg.num_classes)
if 'TTFNet' in cfg.architecture:
res['bbox'][1].append([len(res['bbox'][0])])
if 'segm' in res:
res['segm'] = get_masks(res)
results.append(res)
if iter_id % 100 == 0:
logger.info('Test iter {}'.format(iter_id))
iter_id += 1
if len(res['bbox'][1]) == 0:
if 'bbox' not in res or len(res['bbox'][1]) == 0:
has_bbox = False
images_num += len(res['bbox'][1][0]) if has_bbox else 1
except (StopIteration, fluid.core.EOFException):
......@@ -198,7 +219,7 @@ def eval_results(results,
"""Evaluation for evaluation program results"""
box_ap_stats = []
if metric == 'COCO':
from ppdet.utils.coco_eval import proposal_eval, bbox_eval, mask_eval
from ppdet.utils.coco_eval import proposal_eval, bbox_eval, mask_eval, segm_eval
anno_file = dataset.get_anno()
with_background = dataset.with_background
if 'proposal' in results[0]:
......@@ -225,6 +246,14 @@ def eval_results(results,
output = os.path.join(output_directory, 'mask.json')
mask_eval(
results, anno_file, output, resolution, save_only=save_only)
if 'segm' in results[0]:
output = 'segm.json'
if output_directory:
output = os.path.join(output_directory, output)
mask_ap_stats = segm_eval(
results, anno_file, output, save_only=save_only)
if len(box_ap_stats) == 0:
box_ap_stats = mask_ap_stats
else:
if 'accum_map' in results[-1]:
res = np.mean(results[-1]['accum_map'][0])
......
......@@ -132,9 +132,6 @@ def main():
extra_keys)
sub_eval_prog = sub_eval_prog.clone(True)
#if 'weights' in cfg:
# checkpoint.load_params(exe, sub_eval_prog, cfg.weights)
# load model
exe.run(startup_prog)
if 'weights' in cfg:
......@@ -146,7 +143,6 @@ def main():
results = eval_run(exe, compile_program, loader, keys, values, cls, cfg,
sub_eval_prog, sub_keys, sub_values, resolution)
#print(cfg['EvalReader']['dataset'].__dict__)
# evaluation
# if map_type not set, use default 11point, only use in VOC eval
map_type = cfg.map_type if 'map_type' in cfg else '11point'
......
......@@ -46,11 +46,13 @@ TRT_MIN_SUBGRAPH = {
'Face': 3,
'TTFNet': 3,
'FCOS': 3,
'SOLOv2': 3,
}
RESIZE_SCALE_SET = {
'RCNN',
'RetinaNet',
'FCOS',
'SOLOv2',
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册