From ecc33160dc7fa54ecc0f9aeb9fc9f50cfab61c74 Mon Sep 17 00:00:00 2001 From: wangxinxin08 <69842442+wangxinxin08@users.noreply.github.com> Date: Mon, 19 Apr 2021 11:27:21 +0800 Subject: [PATCH] modify gridmask op (#2692) --- ppdet/data/transform/gridmask_utils.py | 3 ++- static/ppdet/data/transform/gridmask_utils.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ppdet/data/transform/gridmask_utils.py b/ppdet/data/transform/gridmask_utils.py index 2b3e72408..f0775dc31 100644 --- a/ppdet/data/transform/gridmask_utils.py +++ b/ppdet/data/transform/gridmask_utils.py @@ -45,7 +45,8 @@ class Gridmask(object): self.prob = self.st_prob * min(1, 1.0 * curr_iter / self.upper_iter) if np.random.rand() > self.prob: return x - h, w, _ = x.shape + # image should be C, H, W format + _, h, w = x.shape hh = int(1.5 * h) ww = int(1.5 * w) d = np.random.randint(2, h) diff --git a/static/ppdet/data/transform/gridmask_utils.py b/static/ppdet/data/transform/gridmask_utils.py index 115cb1e9d..b370bf007 100644 --- a/static/ppdet/data/transform/gridmask_utils.py +++ b/static/ppdet/data/transform/gridmask_utils.py @@ -45,7 +45,8 @@ class GridMask(object): self.prob = self.st_prob * min(1, 1.0 * curr_iter / self.upper_iter) if np.random.rand() > self.prob: return x - h, w, _ = x.shape + # image should be C, H, W format + _, h, w = x.shape hh = int(1.5 * h) ww = int(1.5 * w) d = np.random.randint(2, h) -- GitLab