未验证 提交 2fb39e1d 编写于 作者: W wangxinxin08 提交者: GitHub

fix problem in GridMaskOp (#2440)

* fix problem in GridMaskOp

* fix problem in GridMask
上级 abc2e7f3
...@@ -45,7 +45,7 @@ class GridMask(object): ...@@ -45,7 +45,7 @@ class GridMask(object):
self.prob = self.st_prob * min(1, 1.0 * curr_iter / self.upper_iter) self.prob = self.st_prob * min(1, 1.0 * curr_iter / self.upper_iter)
if np.random.rand() > self.prob: if np.random.rand() > self.prob:
return x return x
_, h, w = x.shape h, w, _ = x.shape
hh = int(1.5 * h) hh = int(1.5 * h)
ww = int(1.5 * w) ww = int(1.5 * w)
d = np.random.randint(2, h) d = np.random.randint(2, h)
......
...@@ -626,7 +626,7 @@ class GridMaskOp(BaseOperator): ...@@ -626,7 +626,7 @@ class GridMaskOp(BaseOperator):
sample['curr_iter']) sample['curr_iter'])
if not batch_input: if not batch_input:
samples = samples[0] samples = samples[0]
return sample return samples
@register_op @register_op
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册