From 5b18edf5f9c35647f4ae76ba47dbf6173cab68f8 Mon Sep 17 00:00:00 2001
From: sunyanfang01
Date: Thu, 28 May 2020 21:26:00 +0800
Subject: [PATCH] fix for review

---
 paddlex/cv/models/utils/pretrain_weights.py |  28 +++-
 paddlex/cv/models/yolo_v3.py                |  34 ++---
 paddlex/cv/nets/detection/yolo_v3.py        |  14 +-
 paddlex/cv/transforms/box_utils.py          | 121 +++++++++++++++++
 paddlex/cv/transforms/det_transforms.py     | 137 ++------------------
 paddlex/cv/transforms/ops.py                |   5 +-
 6 files changed, 183 insertions(+), 156 deletions(-)

diff --git a/paddlex/cv/models/utils/pretrain_weights.py b/paddlex/cv/models/utils/pretrain_weights.py
index 38bdb54..a2d0a41 100644
--- a/paddlex/cv/models/utils/pretrain_weights.py
+++ b/paddlex/cv/models/utils/pretrain_weights.py
@@ -73,7 +73,7 @@ image_pretrain = {
 }
 
 obj365_pretrain = {
-    'ResNet50_vd_dcn_db_obj365':
+    'ResNet50_vd_obj365':
     'https://paddlemodels.bj.bcebos.com/object_detection/ResNet50_vd_dcn_db_obj365_pretrained.tar',
 }
 
@@ -127,13 +127,27 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir):
         if hasattr(paddlex, 'pretrain_dir'):
             new_save_dir = paddlex.pretrain_dir
         if backbone == 'ResNet50_vd':
-            backbone = 'ResNet50_vd_dcn_db_obj365'
-        assert backbone in obj365_pretrain, "There is not Object365 pretrain weights for {}, you may try ImageNet.".format(
+            backbone = 'ResNet50_vd_obj365'
+        assert backbone in obj365_pretrain, "There are no Object365 pretrained weights for {}, try pretrain_weights='IMAGENET' instead".format(
             backbone)
-        url = obj365_pretrain[backbone]
-        fname = osp.split(url)[-1].split('.')[0]
-        paddlex.utils.download_and_decompress(url, path=new_save_dir)
-        return osp.join(new_save_dir, fname)
+#        url = obj365_pretrain[backbone]
+#        fname = osp.split(url)[-1].split('.')[0]
+#        paddlex.utils.download_and_decompress(url, path=new_save_dir)
+#        return osp.join(new_save_dir, fname)
+        try:
+            hub.download(backbone, save_path=new_save_dir)
+        except Exception as e:
+            if isinstance(e, hub.ResourceNotFoundError):
+                raise Exception("Resource for backbone {} not found".format(
+                    backbone))
+            elif isinstance(e, hub.ServerConnectionError):
+                raise Exception(
+                    "Cannot get resource for backbone {}, please check your internet connection"
+                    .format(backbone))
+            else:
+                raise Exception(
+                    "Unexpected error, please make sure paddlehub >= 1.6.2")
+        return osp.join(new_save_dir, backbone)
     elif flag == 'COCO':
         new_save_dir = save_dir
         if hasattr(paddlex, 'pretrain_dir'):
diff --git a/paddlex/cv/models/yolo_v3.py b/paddlex/cv/models/yolo_v3.py
index d19a489..4ad13c0 100644
--- a/paddlex/cv/models/yolo_v3.py
+++ b/paddlex/cv/models/yolo_v3.py
@@ -44,7 +44,10 @@ class YOLOv3(BaseAPI):
         nms_keep_topk (int): The total number of detection boxes to keep per image after NMS. Defaults to 100.
         nms_iou_threshold (float): The IoU threshold used to suppress detection boxes during NMS. Defaults to 0.45.
         label_smooth (bool): Whether to use label smoothing. Defaults to False.
-        train_random_shapes (list|tuple): Image sizes randomly chosen from this list during training. Defaults to [320, 352, 384, 416, 448, 480, 512, 544, 576, 608].
+        use_iou_loss (bool): Whether to use IoU loss. Defaults to False.
+        use_iou_aware_loss (bool): Whether to use IoU-aware loss. Defaults to False.
+        use_drop_block (bool): Whether to use the DropBlock module. Defaults to False.
+        use_dcn_v2 (bool): Whether to use Deformable Convolution v2. Defaults to False.
     """
 
     def __init__(self,
@@ -62,10 +65,7 @@ class YOLOv3(BaseAPI):
                  use_iou_aware_loss=False,
                  iou_aware_factor=0.4,
                  use_drop_block=False,
-                 use_dcn_v2=False,
-                 train_random_shapes=[
-                     320, 352, 384, 416, 448, 480, 512, 544, 576, 608
-                 ]):
+                 use_dcn_v2=False):
         self.init_params = locals()
         super(YOLOv3, self).__init__('detector')
         backbones = [
@@ -89,7 +89,6 @@ class YOLOv3(BaseAPI):
         self.iou_aware_factor = iou_aware_factor
         self.use_drop_block = use_drop_block
         self.use_dcn_v2 = use_dcn_v2
-        self.train_random_shapes = train_random_shapes
         self.fixed_input_shape = None
         if self.anchors is None:
             self.anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
@@ -139,13 +138,14 @@ class YOLOv3(BaseAPI):
             nms_topk=self.nms_topk,
             nms_keep_topk=self.nms_keep_topk,
             nms_iou_threshold=self.nms_iou_threshold,
-            train_random_shapes=self.train_random_shapes,
             fixed_input_shape=self.fixed_input_shape,
             use_iou_loss=self.use_iou_loss,
             use_iou_aware_loss=self.use_iou_aware_loss,
             iou_aware_factor=self.iou_aware_factor,
            use_drop_block=self.use_drop_block,
-            batch_size=self.train_batch_size if hasattr(self, 'train_batch_size') else 8)
+            batch_size=self.train_batch_size
+            if hasattr(self, 'train_batch_size') else 8,
+            max_shape=self.max_shape if hasattr(self, 'max_shape') else 608)
         inputs = model.generate_inputs()
         model_out = model.build_net(inputs)
         outputs = OrderedDict([('bbox', model_out)])
@@ -254,15 +254,18 @@ class YOLOv3(BaseAPI):
             if isinstance(transform, paddlex.det.transforms.Normalize):
                 transform.is_scale = False
         if self.use_iou_loss or self.use_iou_aware_loss:
-            if self.train_random_shapes is None or len(self.train_random_shapes) == 0:
-                for transform in train_dataset.transforms.transforms:
-                    if isinstance(transform, paddlex.det.transforms.Resize):
-                        self.train_random_shapes = [transform.target_size]
+            self.max_shape = 608
+            for transform in train_dataset.transforms.transforms:
+                if isinstance(transform, paddlex.det.transforms.Resize):
+                    self.max_shape = transform.target_size
+                    break
+            if train_dataset.transforms.batch_transforms is None:
+                train_dataset.transforms.batch_transforms = []
+            else:
+                for bt in train_dataset.transforms.batch_transforms:
+                    if isinstance(bt, paddlex.det.transforms.BatchRandomShape):
+                        self.max_shape = max(bt.random_shapes)
                         break
-            train_dataset.transforms.batch_transforms = []
-            reshape_bt = paddlex.det.transforms.RandomShape
-            train_dataset.transforms.batch_transforms.append(reshape_bt(
-                random_shapes=self.train_random_shapes))
             iou_bt = paddlex.det.transforms.GenerateYoloTarget
             train_dataset.transforms.batch_transforms.append(iou_bt(anchors=self.anchors,
                 anchor_masks=self.anchor_masks,
diff --git a/paddlex/cv/nets/detection/yolo_v3.py b/paddlex/cv/nets/detection/yolo_v3.py
index b854120..9a18e42 100644
--- a/paddlex/cv/nets/detection/yolo_v3.py
+++ b/paddlex/cv/nets/detection/yolo_v3.py
@@ -34,15 +34,17 @@ class YOLOv3:
                  nms_topk=1000,
                  nms_keep_topk=100,
                  nms_iou_threshold=0.45,
-                 train_random_shapes=[
-                     320, 352, 384, 416, 448, 480, 512, 544, 576, 608
-                 ],
+                 train_random_shapes="Deprecated",
                  fixed_input_shape=None,
                  use_iou_loss=False,
                  use_iou_aware_loss=False,
                  iou_aware_factor=0.4,
                  use_drop_block=False,
-                 batch_size=8):
+                 batch_size=8,
+                 max_shape=608):
+        if train_random_shapes != "Deprecated":
+            raise Exception("'train_random_shapes' is deprecated, please " \
+                + "use the BatchRandomShape batch transform instead.")
         if anchors is None:
             anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                        [59, 119], [116, 90], [156, 198], [373, 326]]
@@ -62,7 +64,6 @@ class YOLOv3:
         self.nms_iou_threshold = nms_iou_threshold
         self.norm_decay = 0.0
         self.prefix_name = ''
-        self.train_random_shapes = train_random_shapes
         self.fixed_input_shape = fixed_input_shape
         self.use_iou_loss = use_iou_loss
         self.use_iou_aware_loss = use_iou_aware_loss
@@ -71,6 +72,7 @@ class YOLOv3:
         self.block_size = 3
         self.keep_prob = 0.9
         self.batch_size = batch_size
+        self.max_shape = max_shape
 
     def _head(self, feats):
         outputs = []
@@ -284,7 +286,7 @@ class YOLOv3:
         return yolo_loss_obj(inputs, gt_box, gt_label, gt_score, targets,
                              self.anchors, self.anchor_masks,
                              self.mask_anchors, self.num_classes,
-                             self.prefix_name, max(self.train_random_shapes))
+                             self.prefix_name, self.max_shape)
 
     def _get_prediction(self, inputs, im_size):
         boxes = []
diff --git a/paddlex/cv/transforms/box_utils.py b/paddlex/cv/transforms/box_utils.py
index 02f3c4d..37c7700 100644
--- a/paddlex/cv/transforms/box_utils.py
+++ b/paddlex/cv/transforms/box_utils.py
@@ -221,3 +221,124 @@ def segms_horizontal_flip(segms, height, width):
             import pycocotools.mask as mask_util
             flipped_segms.append(_flip_rle(segm, height, width))
     return flipped_segms
+
+
+class GenerateYoloTarget(object):
+    """Generate the position encodings of the ground truth (annotated) boxes
+    on each feature level for YOLOv3. This transform is only used when YOLOv3
+    computes its fine-grained loss.
+
+    Args:
+        anchors (list|tuple): Widths and heights of the anchor boxes.
+        anchor_masks (list|tuple): Indices of the anchors used by each level when computing the loss.
+        num_classes (int): Number of classes. Defaults to 80.
+        iou_thresh (float): IoU threshold; an anchor whose IoU with a ground truth box
+            exceeds this threshold is also counted as a target. Defaults to 1.0.
+    """
+
+    def __init__(self,
+                 anchors,
+                 anchor_masks,
+                 num_classes=80,
+                 iou_thresh=1.):
+        super(GenerateYoloTarget, self).__init__()
+        self.anchors = anchors
+        self.anchor_masks = anchor_masks
+        self.num_classes = num_classes
+        self.iou_thresh = iou_thresh
+
+    def __call__(self, batch_data):
+        """
+        Args:
+            batch_data (list): A batch of samples, each composed of the information of one image.
+
+        Returns:
+            list: The batch with the following fields added to each sample:
+                - target0 (np.ndarray): position encodings of the ground truth on feature level 0,
+                    with shape (anchor count of level 0, 6 + num_classes, h of level 0, w of level 0).
+                - target1 (np.ndarray): position encodings of the ground truth on feature level 1,
+                    with shape (anchor count of level 1, 6 + num_classes, h of level 1, w of level 1).
+                - ...
+                - targetn (np.ndarray): position encodings of the ground truth on feature level n,
+                    with shape (anchor count of level n, 6 + num_classes, h of level n, w of level n).
+                The number n is determined by the length of anchor_masks.
+        """
+        im = batch_data[0][0]
+        h = im.shape[1]
+        w = im.shape[2]
+        an_hw = np.array(self.anchors) / np.array([[w, h]])
+        for data_id, data in enumerate(batch_data):
+            gt_bbox = data[1]
+            gt_class = data[2]
+            gt_score = data[3]
+            im_shape = data[4]
+            origin_h = float(im_shape[0])
+            origin_w = float(im_shape[1])
+            data_list = list(data)
+            for i, mask in enumerate(self.anchor_masks):
+                downsample_ratio = 32 // pow(2, i)
+                grid_h = int(h / downsample_ratio)
+                grid_w = int(w / downsample_ratio)
+                target = np.zeros(
+                    (len(mask), 6 + self.num_classes, grid_h, grid_w),
+                    dtype=np.float32)
+                for b in range(gt_bbox.shape[0]):
+                    gx = gt_bbox[b, 0] / float(origin_w)
+                    gy = gt_bbox[b, 1] / float(origin_h)
+                    gw = gt_bbox[b, 2] / float(origin_w)
+                    gh = gt_bbox[b, 3] / float(origin_h)
+                    cls = gt_class[b]
+                    score = gt_score[b]
+                    if gw <= 0. or gh <= 0. or score <= 0.:
+                        continue
+                    # find the best matching anchor index
+                    best_iou = 0.
+                    best_idx = -1
+                    for an_idx in range(an_hw.shape[0]):
+                        iou = jaccard_overlap(
+                            [0., 0., gw, gh],
+                            [0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]])
+                        if iou > best_iou:
+                            best_iou = iou
+                            best_idx = an_idx
+                    gi = int(gx * grid_w)
+                    gj = int(gy * grid_h)
+                    # the gt box is regressed on this level if the best
+                    # matching anchor index is in this level's anchor mask
+                    if best_idx in mask:
+                        best_n = mask.index(best_idx)
+                        # x, y, w, h, scale
+                        target[best_n, 0, gj, gi] = gx * grid_w - gi
+                        target[best_n, 1, gj, gi] = gy * grid_h - gj
+                        target[best_n, 2, gj, gi] = np.log(
+                            gw * w / self.anchors[best_idx][0])
+                        target[best_n, 3, gj, gi] = np.log(
+                            gh * h / self.anchors[best_idx][1])
+                        target[best_n, 4, gj, gi] = 2.0 - gw * gh
+                        # objectness records gt_score
+                        target[best_n, 5, gj, gi] = score
+                        # classification
+                        target[best_n, 6 + cls, gj, gi] = 1.
+                    # For non-matched anchors, calculate the target if the iou
+                    # between anchor and gt is larger than iou_thresh
+                    if self.iou_thresh < 1:
+                        for idx, mask_i in enumerate(mask):
+                            if mask_i == best_idx: continue
+                            iou = jaccard_overlap(
+                                [0., 0., gw, gh],
+                                [0., 0., an_hw[mask_i, 0], an_hw[mask_i, 1]])
+                            if iou > self.iou_thresh:
+                                # x, y, w, h, scale
+                                target[idx, 0, gj, gi] = gx * grid_w - gi
+                                target[idx, 1, gj, gi] = gy * grid_h - gj
+                                target[idx, 2, gj, gi] = np.log(
+                                    gw * w / self.anchors[mask_i][0])
+                                target[idx, 3, gj, gi] = np.log(
+                                    gh * h / self.anchors[mask_i][1])
+                                target[idx, 4, gj, gi] = 2.0 - gw * gh
+                                # objectness records gt_score
+                                target[idx, 5, gj, gi] = score
+                                # classification
+                                target[idx, 6 + cls, gj, gi] = 1.
+            data_list.append(target)
+            batch_data[data_id] = tuple(data_list)
+        return batch_data
diff --git a/paddlex/cv/transforms/det_transforms.py b/paddlex/cv/transforms/det_transforms.py
index a72a920..0f944e2 100644
--- a/paddlex/cv/transforms/det_transforms.py
+++ b/paddlex/cv/transforms/det_transforms.py
@@ -1238,7 +1238,7 @@ class ArrangeYOLOv3(DetTransform):
         return outputs
 
 
-class RandomShape(DetTransform):
+class BatchRandomShape(DetTransform):
     """Resize a batch of images.
     Every image in the batch is resized to one size randomly chosen from random_shapes.
 
@@ -1300,128 +1300,6 @@
             data_list[0] = im
             batch_data[data_id] = tuple(data_list)
-        np.save('im.npy', im)
-        return batch_data
-
-
-class GenerateYoloTarget(DetTransform):
-    """Generate the position encodings of the ground truth (annotated) boxes
-    on each feature level for YOLOv3. This transform is only used when YOLOv3
-    computes its fine-grained loss.
-
-    Args:
-        anchors (list|tuple): Widths and heights of the anchor boxes.
-        anchor_masks (list|tuple): Indices of the anchors used by each level when computing the loss.
-        num_classes (int): Number of classes. Defaults to 80.
-        iou_thresh (float): IoU threshold; an anchor whose IoU with a ground truth box
-            exceeds this threshold is also counted as a target. Defaults to 1.0.
-    """
-
-    def __init__(self,
-                 anchors,
-                 anchor_masks,
-                 num_classes=80,
-                 iou_thresh=1.):
-        super(GenerateYoloTarget, self).__init__()
-        self.anchors = anchors
-        self.anchor_masks = anchor_masks
-        self.num_classes = num_classes
-        self.iou_thresh = iou_thresh
-
-    def __call__(self, batch_data):
-        """
-        Args:
-            batch_data (list): A batch of samples, each composed of the information of one image.
-
-        Returns:
-            list: The batch with the following fields added to each sample:
-                - target0 (np.ndarray): position encodings of the ground truth on feature level 0,
-                    with shape (anchor count of level 0, 6 + num_classes, h of level 0, w of level 0).
-                - target1 (np.ndarray): position encodings of the ground truth on feature level 1,
-                    with shape (anchor count of level 1, 6 + num_classes, h of level 1, w of level 1).
-                - ...
-                - targetn (np.ndarray): position encodings of the ground truth on feature level n,
-                    with shape (anchor count of level n, 6 + num_classes, h of level n, w of level n).
-                The number n is determined by the length of anchor_masks.
-        """
-        im = batch_data[0][0]
-        h = im.shape[1]
-        w = im.shape[2]
-        an_hw = np.array(self.anchors) / np.array([[w, h]])
-        for data_id, data in enumerate(batch_data):
-            gt_bbox = data[1]
-            gt_class = data[2]
-            gt_score = data[3]
-            im_shape = data[4]
-            origin_h = float(im_shape[0])
-            origin_w = float(im_shape[1])
-            data_list = list(data)
-            for i, mask in enumerate(self.anchor_masks):
-                downsample_ratio = 32 // pow(2, i)
-                grid_h = int(h / downsample_ratio)
-                grid_w = int(w / downsample_ratio)
-                target = np.zeros(
-                    (len(mask), 6 + self.num_classes, grid_h, grid_w),
-                    dtype=np.float32)
-                for b in range(gt_bbox.shape[0]):
-                    gx = gt_bbox[b, 0] / float(origin_w)
-                    gy = gt_bbox[b, 1] / float(origin_h)
-                    gw = gt_bbox[b, 2] / float(origin_w)
-                    gh = gt_bbox[b, 3] / float(origin_h)
-                    cls = gt_class[b]
-                    score = gt_score[b]
-                    if gw <= 0. or gh <= 0. or score <= 0.:
-                        continue
-                    # find the best matching anchor index
-                    best_iou = 0.
-                    best_idx = -1
-                    for an_idx in range(an_hw.shape[0]):
-                        iou = jaccard_overlap(
-                            [0., 0., gw, gh],
-                            [0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]])
-                        if iou > best_iou:
-                            best_iou = iou
-                            best_idx = an_idx
-                    gi = int(gx * grid_w)
-                    gj = int(gy * grid_h)
-                    # the gt box is regressed on this level if the best
-                    # matching anchor index is in this level's anchor mask
-                    if best_idx in mask:
-                        best_n = mask.index(best_idx)
-                        # x, y, w, h, scale
-                        target[best_n, 0, gj, gi] = gx * grid_w - gi
-                        target[best_n, 1, gj, gi] = gy * grid_h - gj
-                        target[best_n, 2, gj, gi] = np.log(
-                            gw * w / self.anchors[best_idx][0])
-                        target[best_n, 3, gj, gi] = np.log(
-                            gh * h / self.anchors[best_idx][1])
-                        target[best_n, 4, gj, gi] = 2.0 - gw * gh
-                        # objectness records gt_score
-                        target[best_n, 5, gj, gi] = score
-                        # classification
-                        target[best_n, 6 + cls, gj, gi] = 1.
-                    # For non-matched anchors, calculate the target if the iou
-                    # between anchor and gt is larger than iou_thresh
-                    if self.iou_thresh < 1:
-                        for idx, mask_i in enumerate(mask):
-                            if mask_i == best_idx: continue
-                            iou = jaccard_overlap(
-                                [0., 0., gw, gh],
-                                [0., 0., an_hw[mask_i, 0], an_hw[mask_i, 1]])
-                            if iou > self.iou_thresh:
-                                # x, y, w, h, scale
-                                target[idx, 0, gj, gi] = gx * grid_w - gi
-                                target[idx, 1, gj, gi] = gy * grid_h - gj
-                                target[idx, 2, gj, gi] = np.log(
-                                    gw * w / self.anchors[mask_i][0])
-                                target[idx, 3, gj, gi] = np.log(
-                                    gh * h / self.anchors[mask_i][1])
-                                target[idx, 4, gj, gi] = 2.0 - gw * gh
-                                # objectness records gt_score
-                                target[idx, 5, gj, gi] = score
-                                # classification
-                                target[idx, 6 + cls, gj, gi] = 1.
-            data_list.append(target)
-            batch_data[data_id] = tuple(data_list)
-        return batch_data
+        return batch_data
 
 
 class ComposedRCNNTransforms(Compose):
@@ -1489,6 +1367,8 @@ class ComposedYOLOTransforms(Compose):
         mixup_epoch(int): mixup is applied during the first mixup_epoch epochs of training
         mean(list): image mean
         std(list): image standard deviation
+        random_shapes (list): candidate sizes for the batch resize.
+            Defaults to [320, 352, 384, 416, 448, 480, 512, 544, 576, 608].
     """
 
     def __init__(self,
@@ -1495,8 +1375,11 @@ class ComposedYOLOTransforms(Compose):
                  mode,
                  shape=[608, 608],
                  mixup_epoch=250,
                  mean=[0.485, 0.456, 0.406],
-                 std=[0.229, 0.224, 0.225]):
+                 std=[0.229, 0.224, 0.225],
+                 random_shapes=[
+                     320, 352, 384, 416, 448, 480, 512, 544, 576, 608
+                 ]):
         width = shape
         if isinstance(shape, list):
             if shape[0] != shape[1]:
@@ -1517,6 +1400,9 @@ class ComposedYOLOTransforms(Compose):
                     interp='RANDOM'), RandomHorizontalFlip(), Normalize(
                         mean=mean, std=std)
             ]
+            batch_transforms = [
+                BatchRandomShape(random_shapes=random_shapes)
+            ]
         else:
             # transforms for validation/prediction
             transforms = [
@@ -1524,4 +1410,6 @@ class ComposedYOLOTransforms(Compose):
                     target_size=width, interp='CUBIC'), Normalize(
                         mean=mean, std=std)
             ]
-        super(ComposedYOLOTransforms, self).__init__(transforms)
\ No newline at end of file
+            batch_transforms = None
+        super(ComposedYOLOTransforms, self).__init__(transforms,
+                                                     batch_transforms)
\ No newline at end of file
diff --git a/paddlex/cv/transforms/ops.py b/paddlex/cv/transforms/ops.py
index dd517d4..6191875 100644
--- a/paddlex/cv/transforms/ops.py
+++ b/paddlex/cv/transforms/ops.py
@@ -18,8 +18,9 @@ import numpy as np
 from PIL import Image, ImageEnhance
 
 
-def normalize(im, mean, std):
-    im = im / 255.0
+def normalize(im, mean, std, is_scale=True):
+    if is_scale:
+        im = im / 255.0
     im -= mean
     im /= std
     return im
-- 
GitLab
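
Usage sketch (reviewer note, not part of the patch): the snippet below shows how the pieces changed here fit together. It assumes the standard PaddleX 1.x training entry points (pdx.datasets.VOCDetection, model.train); the dataset paths and hyperparameters are placeholders.

import paddlex as pdx
from paddlex.det import transforms

# Multi-scale training is now configured through the composed transforms:
# for mode='train', ComposedYOLOTransforms appends a BatchRandomShape batch
# transform built from random_shapes (replacing the removed
# train_random_shapes argument of YOLOv3).
train_transforms = transforms.ComposedYOLOTransforms(
    mode='train',
    shape=[608, 608],
    random_shapes=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608])

eval_transforms = transforms.ComposedYOLOTransforms(
    mode='eval', shape=[608, 608])

# Placeholder dataset layout; adjust paths to your data.
train_dataset = pdx.datasets.VOCDetection(
    data_dir='dataset',
    file_list='dataset/train_list.txt',
    label_list='dataset/labels.txt',
    transforms=train_transforms,
    shuffle=True)

# With use_iou_loss=True, train() appends a GenerateYoloTarget batch
# transform and derives max_shape from the Resize/BatchRandomShape
# transforms found on the dataset.
model = pdx.det.YOLOv3(
    num_classes=len(train_dataset.labels),
    backbone='DarkNet53',
    use_iou_loss=True)

model.train(
    num_epochs=270,
    train_dataset=train_dataset,
    train_batch_size=8,
    pretrain_weights='IMAGENET',
    save_dir='output/yolov3')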