diff --git a/ppdet/data/transform/gridmask_utils.py b/ppdet/data/transform/gridmask_utils.py index a23e69b20860fe90c7a25472e11de770d238dd07..115cb1e9d291be365ce0436373c9e5eac00acb4c 100644 --- a/ppdet/data/transform/gridmask_utils.py +++ b/ppdet/data/transform/gridmask_utils.py @@ -45,7 +45,7 @@ class GridMask(object): self.prob = self.st_prob * min(1, 1.0 * curr_iter / self.upper_iter) if np.random.rand() > self.prob: return x - _, h, w = x.shape + h, w, _ = x.shape hh = int(1.5 * h) ww = int(1.5 * w) d = np.random.randint(2, h) diff --git a/ppdet/data/transform/operators.py b/ppdet/data/transform/operators.py index 4646e2582146d0ebde5d19668a069f4e6907dcd0..9a575940b40a452afac07e35fb67f7bfba715928 100644 --- a/ppdet/data/transform/operators.py +++ b/ppdet/data/transform/operators.py @@ -626,7 +626,7 @@ class GridMaskOp(BaseOperator): sample['curr_iter']) if not batch_input: samples = samples[0] - return sample + return samples @register_op