未验证 提交 7bb67db3 编写于 作者: 张春乔 提交者: GitHub

fix the div 0 errors in psroi_pool (#49965)

* fix the div 0 errors in psroi_pool

* fix case 7

* rool back sth.
上级 fb74147c
......@@ -339,6 +339,22 @@ class TestPSROIPoolChannelError(unittest.TestCase):
self.assertRaises(ValueError, test_channel_error)
class TestPSROIPoolZeroDivError(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.x = paddle.uniform([2, 490, 28, 28], dtype='float32')
self.boxes = paddle.to_tensor(
[[1, 5, 8, 10], [4, 2, 6, 7], [12, 12, 19, 21]], dtype='float32'
)
self.boxes_num = paddle.to_tensor([1, 2], dtype='int32')
def test_errors(self):
def test_zero_div_error():
paddle.vision.ops.psroi_pool(self.x, self.boxes, self.boxes_num, 0)
self.assertRaises(ValueError, test_zero_div_error)
class TestPSROIPoolStaticAPI(unittest.TestCase):
def setUp(self):
paddle.enable_static()
......
......@@ -1424,6 +1424,8 @@ def psroi_pool(x, boxes, boxes_num, output_size, spatial_scale=1.0, name=None):
output_size = (output_size, output_size)
pooled_height, pooled_width = output_size
assert len(x.shape) == 4, "Input features with shape should be (N, C, H, W)"
if pooled_height * pooled_width == 0:
raise ValueError('output_size should not contain 0.')
output_channels = int(x.shape[1] / (pooled_height * pooled_width))
if in_dygraph_mode():
return _C_ops.psroi_pool(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册