diff --git a/ppdet/data/transform/gridmask_utils.py b/ppdet/data/transform/gridmask_utils.py index 2b3e72408c24f0b32559eb30cc8a451fecdb9d00..f0775dc3177b166b1a50ff4704b2d1edb0315ea2 100644 --- a/ppdet/data/transform/gridmask_utils.py +++ b/ppdet/data/transform/gridmask_utils.py @@ -45,7 +45,8 @@ class Gridmask(object): self.prob = self.st_prob * min(1, 1.0 * curr_iter / self.upper_iter) if np.random.rand() > self.prob: return x - h, w, _ = x.shape + # image should be C, H, W format + _, h, w = x.shape hh = int(1.5 * h) ww = int(1.5 * w) d = np.random.randint(2, h) diff --git a/static/ppdet/data/transform/gridmask_utils.py b/static/ppdet/data/transform/gridmask_utils.py index 115cb1e9d291be365ce0436373c9e5eac00acb4c..b370bf0078a3f960fdf0d65cc7057153eff49b68 100644 --- a/static/ppdet/data/transform/gridmask_utils.py +++ b/static/ppdet/data/transform/gridmask_utils.py @@ -45,7 +45,8 @@ class GridMask(object): self.prob = self.st_prob * min(1, 1.0 * curr_iter / self.upper_iter) if np.random.rand() > self.prob: return x - h, w, _ = x.shape + # image should be C, H, W format + _, h, w = x.shape hh = int(1.5 * h) ww = int(1.5 * w) d = np.random.randint(2, h)