From fab949253ba90e541ea7412a40edec7165e3eff8 Mon Sep 17 00:00:00 2001 From: wangxinxin08 <69842442+wangxinxin08@users.noreply.github.com> Date: Tue, 10 Nov 2020 17:45:45 +0800 Subject: [PATCH] [Dygraph]rename GridMaskOp to GridMask in operator.py to avoid conflict (#1686) * rename GridMaskOp to GridMask in ppdet/data/transform/operator.py to avoid conflict * fix bugs in operator and batch_operator and check_version * modify name in GridMask * add comma according to review --- ppdet/data/transform/batch_operator.py | 2 ++ ppdet/data/transform/operator.py | 50 +++++++++++++++++++++++--- ppdet/utils/check.py | 8 ++--- 3 files changed, 51 insertions(+), 9 deletions(-) diff --git a/ppdet/data/transform/batch_operator.py b/ppdet/data/transform/batch_operator.py index 09d273bfc..e02a7d6e4 100644 --- a/ppdet/data/transform/batch_operator.py +++ b/ppdet/data/transform/batch_operator.py @@ -32,9 +32,11 @@ logger = logging.getLogger(__name__) __all__ = [ 'PadBatchOp', + 'BatchRandomResizeOp', 'Gt2YoloTargetOp', 'Gt2FCOSTargetOp', 'Gt2TTFTargetOp', + 'Gt2Solov2TargetOp', ] diff --git a/ppdet/data/transform/operator.py b/ppdet/data/transform/operator.py index 83bc1064d..76d399be0 100644 --- a/ppdet/data/transform/operator.py +++ b/ppdet/data/transform/operator.py @@ -224,7 +224,7 @@ class NormalizeImageOp(BaseOperator): @register_op -class GridMaskOp(BaseOperator): +class GridMask(BaseOperator): def __init__(self, use_h=True, use_w=True, @@ -246,7 +246,7 @@ class GridMaskOp(BaseOperator): prob (float): max probability to carry out gridmask upper_iter (int): suggested to be equal to global max_iter """ - super(GridMaskOp, self).__init__() + super(GridMask, self).__init__() self.use_h = use_h self.use_w = use_w self.rotate = rotate @@ -1545,13 +1545,13 @@ class NormalizeBoxOp(BaseOperator): @register_op -class BboxXYXY2XYWH(BaseOperator): +class BboxXYXY2XYWHOp(BaseOperator): """ Convert bbox XYXY format to XYWH format. """ def __init__(self): - super(BboxXYXY2XYWH, self).__init__() + super(BboxXYXY2XYWHOp, self).__init__() def apply(self, sample, context=None): assert 'gt_bbox' in sample @@ -1562,6 +1562,48 @@ class BboxXYXY2XYWH(BaseOperator): return sample +@register_op +class PadBoxOp(BaseOperator): + def __init__(self, num_max_boxes=50): + """ + Pad zeros to bboxes if number of bboxes is less than num_max_boxes. + Args: + num_max_boxes (int): the max number of bboxes + """ + self.num_max_boxes = num_max_boxes + super(PadBoxOp, self).__init__() + + def apply(self, sample, context=None): + assert 'gt_bbox' in sample + bbox = sample['gt_bbox'] + gt_num = min(self.num_max_boxes, len(bbox)) + num_max = self.num_max_boxes + fields = context['fields'] if context else [] + pad_bbox = np.zeros((num_max, 4), dtype=np.float32) + if gt_num > 0: + pad_bbox[:gt_num, :] = bbox[:gt_num, :] + sample['gt_bbox'] = pad_bbox + if 'gt_class' in fields: + pad_class = np.zeros((num_max), dtype=np.int32) + if gt_num > 0: + pad_class[:gt_num] = sample['gt_class'][:gt_num, 0] + sample['gt_class'] = pad_class + if 'gt_score' in fields: + pad_score = np.zeros((num_max), dtype=np.float32) + if gt_num > 0: + pad_score[:gt_num] = sample['gt_score'][:gt_num, 0] + sample['gt_score'] = pad_score + # in training, for example in op ExpandImage, + # the bbox and gt_class is expandded, but the difficult is not, + # so, judging by it's length + if 'is_difficult' in fields: + pad_diff = np.zeros((num_max), dtype=np.int32) + if gt_num > 0: + pad_diff[:gt_num] = sample['difficult'][:gt_num, 0] + sample['difficult'] = pad_diff + return sample + + @register_op class DebugVisibleImageOp(BaseOperator): """ diff --git a/ppdet/utils/check.py b/ppdet/utils/check.py index c3e019595..382f99ed3 100644 --- a/ppdet/utils/check.py +++ b/ppdet/utils/check.py @@ -65,13 +65,11 @@ def check_version(version='2.0'): version_split = version.split('.') length = min(len(version_installed), len(version_split)) - flag = False for i in six.moves.range(length): if version_installed[i] > version_split[i]: - flag = True - break - if not flag: - raise Exception(err) + return + if version_installed[i] < version_split[i]: + raise Exception(err) def check_config(cfg): -- GitLab