未验证 提交 fab94925 编写于 作者: W wangxinxin08 提交者: GitHub

[Dygraph]rename GridMaskOp to GridMask in operator.py to avoid conflict (#1686)

* rename GridMaskOp to GridMask in ppdet/data/transform/operator.py to avoid conflict

* fix bugs in operator and batch_operator and check_version

* modify name in GridMask

* add comma according to review
上级 59f8b0fc
...@@ -32,9 +32,11 @@ logger = logging.getLogger(__name__) ...@@ -32,9 +32,11 @@ logger = logging.getLogger(__name__)
__all__ = [ __all__ = [
'PadBatchOp', 'PadBatchOp',
'BatchRandomResizeOp',
'Gt2YoloTargetOp', 'Gt2YoloTargetOp',
'Gt2FCOSTargetOp', 'Gt2FCOSTargetOp',
'Gt2TTFTargetOp', 'Gt2TTFTargetOp',
'Gt2Solov2TargetOp',
] ]
......
...@@ -224,7 +224,7 @@ class NormalizeImageOp(BaseOperator): ...@@ -224,7 +224,7 @@ class NormalizeImageOp(BaseOperator):
@register_op @register_op
class GridMaskOp(BaseOperator): class GridMask(BaseOperator):
def __init__(self, def __init__(self,
use_h=True, use_h=True,
use_w=True, use_w=True,
...@@ -246,7 +246,7 @@ class GridMaskOp(BaseOperator): ...@@ -246,7 +246,7 @@ class GridMaskOp(BaseOperator):
prob (float): max probability to carry out gridmask prob (float): max probability to carry out gridmask
upper_iter (int): suggested to be equal to global max_iter upper_iter (int): suggested to be equal to global max_iter
""" """
super(GridMaskOp, self).__init__() super(GridMask, self).__init__()
self.use_h = use_h self.use_h = use_h
self.use_w = use_w self.use_w = use_w
self.rotate = rotate self.rotate = rotate
...@@ -1545,13 +1545,13 @@ class NormalizeBoxOp(BaseOperator): ...@@ -1545,13 +1545,13 @@ class NormalizeBoxOp(BaseOperator):
@register_op @register_op
class BboxXYXY2XYWH(BaseOperator): class BboxXYXY2XYWHOp(BaseOperator):
""" """
Convert bbox XYXY format to XYWH format. Convert bbox XYXY format to XYWH format.
""" """
def __init__(self): def __init__(self):
super(BboxXYXY2XYWH, self).__init__() super(BboxXYXY2XYWHOp, self).__init__()
def apply(self, sample, context=None): def apply(self, sample, context=None):
assert 'gt_bbox' in sample assert 'gt_bbox' in sample
...@@ -1562,6 +1562,48 @@ class BboxXYXY2XYWH(BaseOperator): ...@@ -1562,6 +1562,48 @@ class BboxXYXY2XYWH(BaseOperator):
return sample return sample
@register_op
class PadBoxOp(BaseOperator):
def __init__(self, num_max_boxes=50):
"""
Pad zeros to bboxes if number of bboxes is less than num_max_boxes.
Args:
num_max_boxes (int): the max number of bboxes
"""
self.num_max_boxes = num_max_boxes
super(PadBoxOp, self).__init__()
def apply(self, sample, context=None):
assert 'gt_bbox' in sample
bbox = sample['gt_bbox']
gt_num = min(self.num_max_boxes, len(bbox))
num_max = self.num_max_boxes
fields = context['fields'] if context else []
pad_bbox = np.zeros((num_max, 4), dtype=np.float32)
if gt_num > 0:
pad_bbox[:gt_num, :] = bbox[:gt_num, :]
sample['gt_bbox'] = pad_bbox
if 'gt_class' in fields:
pad_class = np.zeros((num_max), dtype=np.int32)
if gt_num > 0:
pad_class[:gt_num] = sample['gt_class'][:gt_num, 0]
sample['gt_class'] = pad_class
if 'gt_score' in fields:
pad_score = np.zeros((num_max), dtype=np.float32)
if gt_num > 0:
pad_score[:gt_num] = sample['gt_score'][:gt_num, 0]
sample['gt_score'] = pad_score
# in training, for example in op ExpandImage,
# the bbox and gt_class is expandded, but the difficult is not,
# so, judging by it's length
if 'is_difficult' in fields:
pad_diff = np.zeros((num_max), dtype=np.int32)
if gt_num > 0:
pad_diff[:gt_num] = sample['difficult'][:gt_num, 0]
sample['difficult'] = pad_diff
return sample
@register_op @register_op
class DebugVisibleImageOp(BaseOperator): class DebugVisibleImageOp(BaseOperator):
""" """
......
...@@ -65,13 +65,11 @@ def check_version(version='2.0'): ...@@ -65,13 +65,11 @@ def check_version(version='2.0'):
version_split = version.split('.') version_split = version.split('.')
length = min(len(version_installed), len(version_split)) length = min(len(version_installed), len(version_split))
flag = False
for i in six.moves.range(length): for i in six.moves.range(length):
if version_installed[i] > version_split[i]: if version_installed[i] > version_split[i]:
flag = True return
break if version_installed[i] < version_split[i]:
if not flag: raise Exception(err)
raise Exception(err)
def check_config(cfg): def check_config(cfg):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册