diff --git a/ppdet/data/transform/gridmask_utils.py b/ppdet/data/transform/gridmask_utils.py index f0775dc3177b166b1a50ff4704b2d1edb0315ea2..b0a27f015443701e6f690b96101d3d33fa3fbaaa 100644 --- a/ppdet/data/transform/gridmask_utils.py +++ b/ppdet/data/transform/gridmask_utils.py @@ -45,8 +45,7 @@ class Gridmask(object): self.prob = self.st_prob * min(1, 1.0 * curr_iter / self.upper_iter) if np.random.rand() > self.prob: return x - # image should be C, H, W format - _, h, w = x.shape + h, w, _ = x.shape hh = int(1.5 * h) ww = int(1.5 * w) d = np.random.randint(2, h) @@ -74,7 +73,7 @@ class Gridmask(object): if self.mode == 1: mask = 1 - mask - mask = np.expand_dims(mask, axis=0) + mask = np.expand_dims(mask, axis=-1) if self.offset: offset = (2 * (np.random.rand(h, w) - 0.5)).astype(np.float32) x = (x * mask + offset * (1 - mask)).astype(x.dtype) diff --git a/static/ppdet/data/transform/gridmask_utils.py b/static/ppdet/data/transform/gridmask_utils.py index b370bf0078a3f960fdf0d65cc7057153eff49b68..af1f8d56fd75e75271834de0cf10285a93177319 100644 --- a/static/ppdet/data/transform/gridmask_utils.py +++ b/static/ppdet/data/transform/gridmask_utils.py @@ -46,7 +46,7 @@ class GridMask(object): if np.random.rand() > self.prob: return x # image should be C, H, W format - _, h, w = x.shape + h, w, _ = x.shape hh = int(1.5 * h) ww = int(1.5 * w) d = np.random.randint(2, h) @@ -74,7 +74,7 @@ class GridMask(object): if self.mode == 1: mask = 1 - mask - mask = np.expand_dims(mask, axis=0) + mask = np.expand_dims(mask, axis=-1) if self.offset: offset = (2 * (np.random.rand(h, w) - 0.5)).astype(np.float32) x = (x * mask + offset * (1 - mask)).astype(x.dtype)