From 2fb39e1d2b9ff565e5c8c5173f32f8d8d03917b3 Mon Sep 17 00:00:00 2001 From: wangxinxin08 <69842442+wangxinxin08@users.noreply.github.com> Date: Tue, 6 Apr 2021 10:22:58 +0800 Subject: [PATCH] fix problem in GridMaskOp (#2440) * fix problem in GridMaskOp * fix problem in GridMask --- ppdet/data/transform/gridmask_utils.py | 2 +- ppdet/data/transform/operators.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ppdet/data/transform/gridmask_utils.py b/ppdet/data/transform/gridmask_utils.py index a23e69b20..115cb1e9d 100644 --- a/ppdet/data/transform/gridmask_utils.py +++ b/ppdet/data/transform/gridmask_utils.py @@ -45,7 +45,7 @@ class GridMask(object): self.prob = self.st_prob * min(1, 1.0 * curr_iter / self.upper_iter) if np.random.rand() > self.prob: return x - _, h, w = x.shape + h, w, _ = x.shape hh = int(1.5 * h) ww = int(1.5 * w) d = np.random.randint(2, h) diff --git a/ppdet/data/transform/operators.py b/ppdet/data/transform/operators.py index 4646e2582..9a575940b 100644 --- a/ppdet/data/transform/operators.py +++ b/ppdet/data/transform/operators.py @@ -626,7 +626,7 @@ class GridMaskOp(BaseOperator): sample['curr_iter']) if not batch_input: samples = samples[0] - return sample + return samples @register_op -- GitLab