cutmix的实现是否已完成?
Created by: Edwardwaw
class CutmixImage(BaseOperator):
    """CutMix augmentation operator for detection samples.

    Implements the image-mixing step of "CutMix: Regularization Strategy to
    Train Strong Classifiers with Localizable Features", see
    https://arxiv.org/abs/1905.04899. A random rectangle taken from the
    image stored under ``sample['cutmix']`` is pasted onto
    ``sample['image']``, and the ``gt_bbox`` / ``gt_class`` / ``gt_score``
    arrays of both samples are concatenated.

    NOTE(review): all boxes of both samples are kept, whether or not they
    fall inside the pasted rectangle, and scores are weighted by the sampled
    factor rather than by the pasted-area ratio after clipping. This may be
    unsuitable for detection targets -- confirm the intended behaviour.
    """

    def __init__(self, alpha=1.5, beta=1.5):
        """Store and validate the Beta-distribution parameters.

        Args:
            alpha (float): alpha parameter of the Beta distribution.
            beta (float): beta parameter of the Beta distribution.

        Raises:
            ValueError: if ``alpha`` or ``beta`` is not strictly positive.
        """
        super(CutmixImage, self).__init__()
        self.alpha = alpha
        self.beta = beta
        if self.alpha <= 0.0:
            raise ValueError("alpha shold be positive in {}".format(self))
        if self.beta <= 0.0:
            raise ValueError("beta shold be positive in {}".format(self))

    def _rand_bbox(self, img1, img2, factor):
        """Paste a random rectangle of ``img2`` onto ``img1``.

        Both images are copied as float32 into the top-left corner of a
        zero-filled canvas whose height/width are the maxima over the two
        inputs, so that they can be indexed with the same rectangle.

        Args:
            img1: array indexed as [row, col, channel]; receives the patch.
            img2: array indexed as [row, col, channel]; provides the patch.
            factor (float): mixing factor; the rectangle sides are
                ``sqrt(1 - factor)`` times the canvas sides.

        Returns:
            float32 array of shape ``(h, w, img1.shape[2])`` holding the
            padded ``img1`` with the rectangle replaced by ``img2`` content.
        """
        h = max(img1.shape[0], img2.shape[0])
        w = max(img1.shape[1], img2.shape[1])

        # Side ratio chosen so the rectangle area is (1 - factor) of the
        # canvas area before clipping.
        cut_rat = np.sqrt(1. - factor)
        # Builtin int() truncates like the removed ``np.int`` alias did.
        cut_w = int(w * cut_rat)
        cut_h = int(h * cut_rat)

        # Rectangle centre is sampled uniformly over the canvas.
        cx = np.random.randint(w)
        cy = np.random.randint(h)

        # Clip the rectangle to the canvas; it may shrink near the borders.
        bbx1 = np.clip(cx - cut_w // 2, 0, w)
        bby1 = np.clip(cy - cut_h // 2, 0, h)
        bbx2 = np.clip(cx + cut_w // 2, 0, w)
        bby2 = np.clip(cy + cut_h // 2, 0, h)

        img_1 = np.zeros((h, w, img1.shape[2]), 'float32')
        img_1[:img1.shape[0], :img1.shape[1], :] = img1.astype('float32')
        img_2 = np.zeros((h, w, img2.shape[2]), 'float32')
        img_2[:img2.shape[0], :img2.shape[1], :] = img2.astype('float32')

        img_1[bby1:bby2, bbx1:bbx2, :] = img_2[bby1:bby2, bbx1:bbx2, :]
        return img_1

    def __call__(self, sample, context=None):
        """Apply CutMix to ``sample`` using the sample under its 'cutmix' key.

        Args:
            sample (dict): must provide 'image', 'gt_bbox', 'gt_class' and
                'gt_score'; the same keys are read from
                ``sample['cutmix']`` when that entry is present.
            context: unused by this operator.

        Returns:
            dict: ``sample`` unchanged when it has no 'cutmix' entry;
            ``sample`` without the 'cutmix' entry when the factor is >= 1;
            the nested ``sample['cutmix']`` when the factor is <= 0;
            otherwise ``sample`` updated in place with the mixed image, the
            concatenated ground truth and the new 'h' / 'w'.
        """
        if 'cutmix' not in sample:
            return sample

        factor = np.random.beta(self.alpha, self.beta)
        factor = max(0.0, min(1.0, factor))
        if factor >= 1.0:
            # Nothing would be pasted: keep the first sample as it is.
            sample.pop('cutmix')
            return sample
        if factor <= 0.0:
            # The whole canvas would be pasted: use the second sample.
            return sample['cutmix']

        img1 = sample['image']
        img2 = sample['cutmix']['image']
        img = self._rand_bbox(img1, img2, factor)

        gt_bbox1 = sample['gt_bbox']
        gt_bbox2 = sample['cutmix']['gt_bbox']
        gt_bbox = np.concatenate((gt_bbox1, gt_bbox2), axis=0)

        gt_class1 = sample['gt_class']
        gt_class2 = sample['cutmix']['gt_class']
        gt_class = np.concatenate((gt_class1, gt_class2), axis=0)

        # Scores of the first sample are scaled by factor, those of the
        # second by (1 - factor).
        gt_score1 = sample['gt_score']
        gt_score2 = sample['cutmix']['gt_score']
        gt_score = np.concatenate(
            (gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)

        sample['image'] = img
        sample['gt_bbox'] = gt_bbox
        sample['gt_score'] = gt_score
        sample['gt_class'] = gt_class
        sample['h'] = img.shape[0]
        sample['w'] = img.shape[1]
        sample.pop('cutmix')
        return sample
请问当前版本的 repo 中，CutMix 的效果是否经过检验？直觉上来说，该部分对 gt_score 和 gt_bbox 的处理似乎有些不妥。