提交 5b18edf5 编写于 作者: S sunyanfang01

fix for review

上级 612acfa8
......@@ -73,7 +73,7 @@ image_pretrain = {
}
obj365_pretrain = {
'ResNet50_vd_dcn_db_obj365':
'ResNet50_vd_obj365':
'https://paddlemodels.bj.bcebos.com/object_detection/ResNet50_vd_dcn_db_obj365_pretrained.tar',
}
......@@ -127,13 +127,27 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir):
if hasattr(paddlex, 'pretrain_dir'):
new_save_dir = paddlex.pretrain_dir
if backbone == 'ResNet50_vd':
backbone = 'ResNet50_vd_dcn_db_obj365'
assert backbone in obj365_pretrain, "There is not Object365 pretrain weights for {}, you may try ImageNet.".format(
backbone = 'ResNet50_vd_obj365'
assert backbone in obj365_pretrain, "There is not Object365 pretrain weights for {}, try use pretrain_weights='IMAGENET'".format(
backbone)
url = obj365_pretrain[backbone]
fname = osp.split(url)[-1].split('.')[0]
paddlex.utils.download_and_decompress(url, path=new_save_dir)
return osp.join(new_save_dir, fname)
# url = obj365_pretrain[backbone]
# fname = osp.split(url)[-1].split('.')[0]
# paddlex.utils.download_and_decompress(url, path=new_save_dir)
# return osp.join(new_save_dir, fname)
try:
    hub.download(backbone, save_path=new_save_dir)
except Exception as e:
    # Classify the failure by inspecting the exception that was caught.
    # isinstance() needs the object as its first argument; calling it with
    # only the class raises TypeError and hides the real cause.
    # The hub exception classes are looked up with getattr so that a
    # paddlehub build lacking them falls through to the version hint
    # below instead of raising AttributeError inside this handler.
    not_found_error = getattr(hub, 'ResourceNotFoundError', None)
    connection_error = getattr(hub, 'ServerConnectionError', None)
    if not_found_error is not None and isinstance(e, not_found_error):
        raise Exception("Resource for backbone {} not found".format(
            backbone))
    elif connection_error is not None and isinstance(e, connection_error):
        raise Exception(
            "Cannot get reource for backbone {}, please check your internet connecgtion"
            .format(backbone))
    else:
        raise Exception(
            "Unexpected error, please make sure paddlehub >= 1.6.2")
return osp.join(new_save_dir, backbone)
elif flag == 'COCO':
new_save_dir = save_dir
if hasattr(paddlex, 'pretrain_dir'):
......
......@@ -44,7 +44,10 @@ class YOLOv3(BaseAPI):
nms_keep_topk (int): 进行NMS后,每个图像要保留的总检测框数。默认为100。
nms_iou_threshold (float): 进行NMS时,用于剔除检测框IOU的阈值。默认为0.45。
label_smooth (bool): 是否使用label smooth。默认值为False。
train_random_shapes (list|tuple): 训练时从列表中随机选择图像大小。默认值为[320, 352, 384, 416, 448, 480, 512, 544, 576, 608]。
use_iou_loss (bool): 是否使用IoU Loss。默认为False。
use_iou_aware_loss (bool): 是否使用IoU Aware Loss。默认为False。
use_drop_block (bool): 是否使用DropBlock模块。默认为False。
use_dcn_v2 (bool): 是否使用Deformable Convolution v2(可变形卷积)。默认为False。
"""
def __init__(self,
......@@ -62,10 +65,7 @@ class YOLOv3(BaseAPI):
use_iou_aware_loss=False,
iou_aware_factor=0.4,
use_drop_block=False,
use_dcn_v2=False,
train_random_shapes=[
320, 352, 384, 416, 448, 480, 512, 544, 576, 608
]):
use_dcn_v2=False):
self.init_params = locals()
super(YOLOv3, self).__init__('detector')
backbones = [
......@@ -89,7 +89,6 @@ class YOLOv3(BaseAPI):
self.iou_aware_factor = iou_aware_factor
self.use_drop_block = use_drop_block
self.use_dcn_v2 = use_dcn_v2
self.train_random_shapes = train_random_shapes
self.fixed_input_shape = None
if self.anchors is None:
self.anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
......@@ -139,13 +138,13 @@ class YOLOv3(BaseAPI):
nms_topk=self.nms_topk,
nms_keep_topk=self.nms_keep_topk,
nms_iou_threshold=self.nms_iou_threshold,
train_random_shapes=self.train_random_shapes,
fixed_input_shape=self.fixed_input_shape,
use_iou_loss=self.use_iou_loss,
use_iou_aware_loss=self.use_iou_aware_loss,
iou_aware_factor=self.iou_aware_factor,
use_drop_block=self.use_drop_block,
batch_size=self.train_batch_size if hasattr(self, 'train_batch_size') else 8)
batch_size=self.train_batch_size,
max_shape=self.max_shape)
inputs = model.generate_inputs()
model_out = model.build_net(inputs)
outputs = OrderedDict([('bbox', model_out)])
......@@ -254,15 +253,18 @@ class YOLOv3(BaseAPI):
if isinstance(transform, paddlex.det.transforms.Normalize):
transform.is_scale = False
if self.use_iou_loss or self.use_iou_aware_loss:
if self.train_random_shapes is None or len(self.train_random_shapes) == 0:
for transform in train_dataset.transforms.transforms:
if isinstance(transform, paddlex.det.transforms.Resize):
self.train_random_shapes = [transform.target_size]
self.max_shape = 0
for transform in train_dataset.transforms.transforms:
if isinstance(transform, paddlex.det.transforms.Resize):
self.max_shape = [transform.target_size]
break
if train_dataset.transforms.batch_transforms is None:
train_dataset.transforms.batch_transforms = []
else:
for bt in train_dataset.transforms.batch_transforms:
if isinstance(bt, paddlex.det.transforms.BatchRandomShape):
self.max_shape = max(bt.random_shapes)
break
train_dataset.transforms.batch_transforms = []
reshape_bt = paddlex.det.transforms.RandomShape
train_dataset.transforms.batch_transforms.append(reshape_bt(
random_shapes=self.train_random_shapes))
iou_bt = paddlex.det.transforms.GenerateYoloTarget
train_dataset.transforms.batch_transforms.append(iou_bt(anchors=self.anchors,
anchor_masks=self.anchor_masks,
......
......@@ -34,15 +34,17 @@ class YOLOv3:
nms_topk=1000,
nms_keep_topk=100,
nms_iou_threshold=0.45,
train_random_shapes=[
320, 352, 384, 416, 448, 480, 512, 544, 576, 608
],
train_random_shapes="Deprecated",
fixed_input_shape=None,
use_iou_loss=False,
use_iou_aware_loss=False,
iou_aware_factor=0.4,
use_drop_block=False,
batch_size=8):
batch_size=8,
max_shape = 608):
if train_random_shapes != "Deprecated":
raise Exception("The 'train_random_shapes' is deprecated. If you want to set train_random_shapes, " \
+ "you can use BatchRandomShape. The details you can see ")
if anchors is None:
anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
[59, 119], [116, 90], [156, 198], [373, 326]]
......@@ -62,7 +64,6 @@ class YOLOv3:
self.nms_iou_threshold = nms_iou_threshold
self.norm_decay = 0.0
self.prefix_name = ''
self.train_random_shapes = train_random_shapes
self.fixed_input_shape = fixed_input_shape
self.use_iou_loss = use_iou_loss
self.use_iou_aware_loss = use_iou_aware_loss
......@@ -71,6 +72,7 @@ class YOLOv3:
self.block_size = 3
self.keep_prob = 0.9
self.batch_size = batch_size
self.max_shape = max_shape
def _head(self, feats):
outputs = []
......@@ -284,7 +286,7 @@ class YOLOv3:
return yolo_loss_obj(inputs, gt_box, gt_label, gt_score, targets,
self.anchors, self.anchor_masks,
self.mask_anchors, self.num_classes,
self.prefix_name, max(self.train_random_shapes))
self.prefix_name, self.max_shape)
def _get_prediction(self, inputs, im_size):
boxes = []
......
......@@ -221,3 +221,124 @@ def segms_horizontal_flip(segms, height, width):
import pycocotools.mask as mask_util
flipped_segms.append(_flip_rle(segm, height, width))
return flipped_segms
class GenerateYoloTarget(object):
    """Encode YOLOv3 ground-truth boxes into one target array per feature level.

    This transform is only used when YOLOv3 computes its fine-grained loss.

    Args:
        anchors (list|tuple): width and height of every anchor box.
        anchor_masks (list|tuple): for each feature level, the indices into
            ``anchors`` that are used when computing the loss on that level.
        num_classes (int): number of classes. Defaults to 80.
        iou_thresh (float): IoU threshold. When it is below 1, every anchor of
            a level whose IoU with a ground-truth box exceeds it is also
            written into the target, not only the best-matching anchor.
            Defaults to 1.0.
    """

    def __init__(self, anchors, anchor_masks, num_classes=80, iou_thresh=1.):
        super(GenerateYoloTarget, self).__init__()
        self.anchors = anchors
        self.anchor_masks = anchor_masks
        self.num_classes = num_classes
        self.iou_thresh = iou_thresh

    def _match_boxes(self, data, an_hw):
        """Normalise the boxes of one sample and match each one to the anchors.

        The result does not depend on the feature level, so it is computed
        once per sample and reused for every level.

        Args:
            data (tuple): one sample; fields 1-4 are read as gt_bbox,
                gt_class, gt_score and im_shape.
            an_hw (np.ndarray): anchors divided by the network input (w, h).

        Returns:
            list: one ``(gx, gy, gw, gh, cls, score, best_idx, ious)`` tuple
            per valid box, in box order. ``best_idx`` is -1 when no anchor
            has a positive IoU with the box.
        """
        gt_bbox = data[1]
        gt_class = data[2]
        gt_score = data[3]
        im_shape = data[4]
        origin_h = float(im_shape[0])
        origin_w = float(im_shape[1])

        matched = []
        for b in range(gt_bbox.shape[0]):
            gx = gt_bbox[b, 0] / origin_w
            gy = gt_bbox[b, 1] / origin_h
            gw = gt_bbox[b, 2] / origin_w
            gh = gt_bbox[b, 3] / origin_h
            cls = gt_class[b]
            score = gt_score[b]
            # Skip degenerate boxes and boxes with no score.
            if gw <= 0. or gh <= 0. or score <= 0.:
                continue

            # Shape-only IoU: both boxes are anchored at the origin.
            ious = [
                jaccard_overlap([0., 0., gw, gh],
                                [0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]])
                for an_idx in range(an_hw.shape[0])
            ]
            # The first anchor with the strictly largest IoU wins.
            best_iou = 0.
            best_idx = -1
            for an_idx, iou in enumerate(ious):
                if iou > best_iou:
                    best_iou = iou
                    best_idx = an_idx
            matched.append((gx, gy, gw, gh, cls, score, best_idx, ious))
        return matched

    def __call__(self, batch_data):
        """Append one target array per feature level to every sample.

        Args:
            batch_data (list): batch of samples, each a tuple of
                image-related fields. Field 0 is the image (its shape[1] and
                shape[2] are taken as height and width); fields 1-4 are read
                as gt_bbox, gt_class, gt_score and im_shape.

        Returns:
            list: the same batch, where every sample has gained the fields
            target0 ... targetn (np.ndarray). target_i has shape
            (number of anchors of level i, 6 + num_classes,
            height of level i, width of level i).
            n is determined by the length of anchor_masks.
        """
        if not batch_data:
            # Nothing to encode; also avoids indexing into an empty batch.
            return batch_data

        # All images in a batch share the network input size.
        im = batch_data[0][0]
        h = im.shape[1]
        w = im.shape[2]
        an_hw = np.array(self.anchors) / np.array([[w, h]])

        for data_id, data in enumerate(batch_data):
            matched = self._match_boxes(data, an_hw)
            data_list = list(data)
            for i, mask in enumerate(self.anchor_masks):
                # Level 0 has stride 32; each following level halves it.
                downsample_ratio = 32 // pow(2, i)
                grid_h = int(h / downsample_ratio)
                grid_w = int(w / downsample_ratio)
                target = np.zeros(
                    (len(mask), 6 + self.num_classes, grid_h, grid_w),
                    dtype=np.float32)

                for gx, gy, gw, gh, cls, score, best_idx, ious in matched:
                    gi = int(gx * grid_w)
                    gj = int(gy * grid_h)

                    # (slot in this level, anchor index) pairs to fill.
                    slots = []
                    # The box is regressed on this level only if its best
                    # matching anchor belongs to this level's mask.
                    if best_idx in mask:
                        slots.append((mask.index(best_idx), best_idx))
                    # For non-matched anchors, also fill the target if the
                    # IoU between anchor and box is larger than iou_thresh.
                    if self.iou_thresh < 1:
                        for idx, mask_i in enumerate(mask):
                            if mask_i == best_idx:
                                continue
                            if ious[mask_i] > self.iou_thresh:
                                slots.append((idx, mask_i))

                    for slot, an_idx in slots:
                        # x, y offsets inside the cell, log-scale w, h
                        target[slot, 0, gj, gi] = gx * grid_w - gi
                        target[slot, 1, gj, gi] = gy * grid_h - gj
                        target[slot, 2, gj, gi] = np.log(
                            gw * w / self.anchors[an_idx][0])
                        target[slot, 3, gj, gi] = np.log(
                            gh * h / self.anchors[an_idx][1])
                        # scale weight
                        target[slot, 4, gj, gi] = 2.0 - gw * gh
                        # objectness records gt_score
                        target[slot, 5, gj, gi] = score
                        # classification
                        target[slot, 6 + cls, gj, gi] = 1.
                data_list.append(target)
            batch_data[data_id] = tuple(data_list)
        return batch_data
......@@ -1238,7 +1238,7 @@ class ArrangeYOLOv3(DetTransform):
return outputs
class RandomShape(DetTransform):
class BatchRandomShape(DetTransform):
"""调整图像大小(resize)。
对batch数据中的每张图像全部resize到random_shapes中任意一个大小。
......@@ -1300,128 +1300,7 @@ class RandomShape(DetTransform):
data_list[0] = im
batch_data[data_id] = tuple(data_list)
np.save('im.npy', im)
return batch_data
class GenerateYoloTarget(DetTransform):
    """Generate, for every feature level, the encoded position targets of
    the YOLOv3 ground-truth boxes.

    This transform is only used when YOLOv3 computes its fine-grained loss.

    Args:
        anchors (list|tuple): width and height of the anchor boxes.
        anchor_masks (list|tuple): indices of the anchors used when
            computing the loss, one group per feature level.
        num_classes (int): number of classes. Defaults to 80.
        iou_thresh (float): IoU threshold; when the IoU between an anchor
            and a ground-truth box is larger than it, the anchor is written
            into the target. Defaults to 1.0.
    """
    def __init__(self,
                 anchors,
                 anchor_masks,
                 num_classes=80,
                 iou_thresh=1.):
        super(GenerateYoloTarget, self).__init__()
        self.anchors = anchors
        self.anchor_masks = anchor_masks
        self.num_classes = num_classes
        self.iou_thresh = iou_thresh
    def __call__(self, batch_data):
        """
        Args:
            batch_data (list): batch data made of image-related fields.
        Returns:
            list: batch data made of image-related fields.
                The fields newly appended to every sample are:
                - target0 (np.ndarray): encoded ground truth on feature level 0,
                    shape (anchors of level 0, 6 + num_classes, h of level 0, w of level 0).
                - target1 (np.ndarray): encoded ground truth on feature level 1,
                    shape (anchors of level 1, 6 + num_classes, h of level 1, w of level 1).
                - ...
                - targetn (np.ndarray): encoded ground truth on feature level n,
                    shape (anchors of level n, 6 + num_classes, h of level n, w of level n).
                n is determined by the length of anchor_masks.
        """
        # The first image fixes h and w for the whole batch
        # (assumes a CHW layout and equally sized images -- TODO confirm).
        im = batch_data[0][0]
        h = im.shape[1]
        w = im.shape[2]
        # Anchors expressed relative to the network input size.
        an_hw = np.array(self.anchors) / np.array([[w, h]])
        for data_id, data in enumerate(batch_data):
            gt_bbox = data[1]
            gt_class = data[2]
            gt_score = data[3]
            im_shape = data[4]
            origin_h = float(im_shape[0])
            origin_w = float(im_shape[1])
            data_list = list(data)
            for i, mask in enumerate(self.anchor_masks):
                # Level 0 uses stride 32; every following level halves it.
                downsample_ratio = 32 // pow(2, i)
                grid_h = int(h / downsample_ratio)
                grid_w = int(w / downsample_ratio)
                target = np.zeros(
                    (len(mask), 6 + self.num_classes, grid_h, grid_w),
                    dtype=np.float32)
                for b in range(gt_bbox.shape[0]):
                    # Box normalised by the original image size
                    # (presumably center-x, center-y, w, h -- verify against caller).
                    gx = gt_bbox[b, 0] / float(origin_w)
                    gy = gt_bbox[b, 1] / float(origin_h)
                    gw = gt_bbox[b, 2] / float(origin_w)
                    gh = gt_bbox[b, 3] / float(origin_h)
                    cls = gt_class[b]
                    score = gt_score[b]
                    # Skip degenerate boxes and boxes with no score.
                    if gw <= 0. or gh <= 0. or score <= 0.:
                        continue
                    # find best match anchor index
                    # (shape-only IoU: both boxes are placed at the origin)
                    best_iou = 0.
                    best_idx = -1
                    for an_idx in range(an_hw.shape[0]):
                        iou = jaccard_overlap(
                            [0., 0., gw, gh],
                            [0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]])
                        if iou > best_iou:
                            best_iou = iou
                            best_idx = an_idx
                    # Grid cell that receives the box.
                    gi = int(gx * grid_w)
                    gj = int(gy * grid_h)
                    # The gt box is regressed on this level only if its best
                    # matching anchor index is in the anchor mask of this level.
                    if best_idx in mask:
                        best_n = mask.index(best_idx)
                        # x, y, w, h, scale
                        target[best_n, 0, gj, gi] = gx * grid_w - gi
                        target[best_n, 1, gj, gi] = gy * grid_h - gj
                        target[best_n, 2, gj, gi] = np.log(
                            gw * w / self.anchors[best_idx][0])
                        target[best_n, 3, gj, gi] = np.log(
                            gh * h / self.anchors[best_idx][1])
                        target[best_n, 4, gj, gi] = 2.0 - gw * gh
                        # objectness record gt_score
                        target[best_n, 5, gj, gi] = score
                        # classification
                        target[best_n, 6 + cls, gj, gi] = 1.
                    # For non-matched anchors, calculate the target if the iou
                    # between anchor and gt is larger than iou_thresh
                    if self.iou_thresh < 1:
                        for idx, mask_i in enumerate(mask):
                            if mask_i == best_idx: continue
                            iou = jaccard_overlap(
                                [0., 0., gw, gh],
                                [0., 0., an_hw[mask_i, 0], an_hw[mask_i, 1]])
                            if iou > self.iou_thresh:
                                # x, y, w, h, scale
                                target[idx, 0, gj, gi] = gx * grid_w - gi
                                target[idx, 1, gj, gi] = gy * grid_h - gj
                                target[idx, 2, gj, gi] = np.log(
                                    gw * w / self.anchors[mask_i][0])
                                target[idx, 3, gj, gi] = np.log(
                                    gh * h / self.anchors[mask_i][1])
                                target[idx, 4, gj, gi] = 2.0 - gw * gh
                                # objectness record gt_score
                                target[idx, 5, gj, gi] = score
                                # classification
                                target[idx, 6 + cls, gj, gi] = 1.
                # One target array is appended per feature level.
                data_list.append(target)
            batch_data[data_id] = tuple(data_list)
        return batch_data
return batch_data
class ComposedRCNNTransforms(Compose):
......@@ -1489,6 +1368,8 @@ class ComposedYOLOTransforms(Compose):
mixup_epoch(int): 模型训练过程中,前mixup_epoch会使用mixup策略
mean(list): 图像均值
std(list): 图像方差
random_shapes (list): resize大小选择列表。
默认为[320, 352, 384, 416, 448, 480, 512, 544, 576, 608]。
"""
def __init__(self,
......@@ -1496,7 +1377,10 @@ class ComposedYOLOTransforms(Compose):
shape=[608, 608],
mixup_epoch=250,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]):
std=[0.229, 0.224, 0.225],
random_shapes=[
320, 352, 384, 416, 448, 480, 512, 544, 576, 608
]):
width = shape
if isinstance(shape, list):
if shape[0] != shape[1]:
......@@ -1517,6 +1401,9 @@ class ComposedYOLOTransforms(Compose):
interp='RANDOM'), RandomHorizontalFlip(), Normalize(
mean=mean, std=std)
]
batch_transforms = [
BatchRandomShape(random_shapes=random_shapes)
]
else:
# 验证/预测时的transforms
transforms = [
......@@ -1524,4 +1411,4 @@ class ComposedYOLOTransforms(Compose):
target_size=width, interp='CUBIC'), Normalize(
mean=mean, std=std)
]
super(ComposedYOLOTransforms, self).__init__(transforms)
\ No newline at end of file
super(ComposedYOLOTransforms, self).__init__(transforms, batch_transforms)
\ No newline at end of file
......@@ -18,8 +18,9 @@ import numpy as np
from PIL import Image, ImageEnhance
def normalize(im, mean, std, is_scale=True):
    """Normalize an image by subtracting ``mean`` and dividing by ``std``.

    Args:
        im (np.ndarray): image array; ``mean`` and ``std`` must broadcast
            against it.
        mean (list): mean value(s) subtracted from the image.
        std (list): standard deviation(s) the image is divided by.
        is_scale (bool): whether to divide the image by 255.0 before
            normalizing. Defaults to True.

    Returns:
        np.ndarray: a new floating-point array; ``im`` itself is not modified.
    """
    if is_scale:
        # True division allocates a new floating-point array.
        im = im / 255.0
    else:
        # Work on a floating-point copy so the in-place operations below
        # neither fail on integer images nor modify the caller's array.
        im = np.array(im, dtype=np.promote_types(im.dtype, np.float32))
    im -= mean
    im /= std
    return im
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册