From 67e0d761e1afa8d274d5b3a75fbe3ccccf49244a Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Mon, 19 Apr 2021 13:21:39 +0800 Subject: [PATCH] fix gridmask (#2694) --- ppdet/data/transform/gridmask_utils.py | 5 ++--- static/ppdet/data/transform/gridmask_utils.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/ppdet/data/transform/gridmask_utils.py b/ppdet/data/transform/gridmask_utils.py index f0775dc31..b0a27f015 100644 --- a/ppdet/data/transform/gridmask_utils.py +++ b/ppdet/data/transform/gridmask_utils.py @@ -45,8 +45,7 @@ class Gridmask(object): self.prob = self.st_prob * min(1, 1.0 * curr_iter / self.upper_iter) if np.random.rand() > self.prob: return x - # image should be C, H, W format - _, h, w = x.shape + h, w, _ = x.shape hh = int(1.5 * h) ww = int(1.5 * w) d = np.random.randint(2, h) @@ -74,7 +73,7 @@ class Gridmask(object): if self.mode == 1: mask = 1 - mask - mask = np.expand_dims(mask, axis=0) + mask = np.expand_dims(mask, axis=-1) if self.offset: offset = (2 * (np.random.rand(h, w) - 0.5)).astype(np.float32) x = (x * mask + offset * (1 - mask)).astype(x.dtype) diff --git a/static/ppdet/data/transform/gridmask_utils.py b/static/ppdet/data/transform/gridmask_utils.py index b370bf007..af1f8d56f 100644 --- a/static/ppdet/data/transform/gridmask_utils.py +++ b/static/ppdet/data/transform/gridmask_utils.py @@ -46,7 +46,7 @@ class GridMask(object): if np.random.rand() > self.prob: return x # image should be C, H, W format - _, h, w = x.shape + h, w, _ = x.shape hh = int(1.5 * h) ww = int(1.5 * w) d = np.random.randint(2, h) @@ -74,7 +74,7 @@ class GridMask(object): if self.mode == 1: mask = 1 - mask - mask = np.expand_dims(mask, axis=0) + mask = np.expand_dims(mask, axis=-1) if self.offset: offset = (2 * (np.random.rand(h, w) - 0.5)).astype(np.float32) x = (x * mask + offset * (1 - mask)).astype(x.dtype) -- GitLab