未验证 提交 98693428 编写于 作者: F FlyingQianMM 提交者: GitHub

[Zero-Dim] support input 0D Tensor for maximum,minimum,allclose,sigmoid_focal_loss (#49616)

* [Zero-Dim] support input 0D Tensor for maximum,minimum,allclose,sigmoid_focal_loss

* [Zero-Dim] add backward test for sigmoid_focal_loss with 0-D input Tensor
上级 72b2e486
......@@ -86,7 +86,13 @@ void AllValueCompareInferMeta(const MetaTensor& x,
MetaConfig config) {
detail::BinarySameInputDimsCheck(x, y, config);
auto x_dims = x.dims();
auto y_dims = y.dims();
if (x_dims.size() == 0 && y_dims.size() == 0) {
out->set_dims(phi::make_ddim({}));
} else {
out->set_dims(phi::make_ddim({1}));
}
out->set_dtype(DataType::BOOL);
}
......
......@@ -283,6 +283,8 @@ binary_api_list = [
paddle.logical_and,
paddle.logical_or,
paddle.logical_xor,
paddle.maximum,
paddle.minimum,
]
binary_int_api_list = [
......@@ -994,6 +996,35 @@ class TestSundryAPI(unittest.TestCase):
# check grad shape with 1D repeats
self.assertEqual(x.grad.shape, [])
def test_sigmoid_focal_loss(self):
logit = paddle.to_tensor(
[[0.97, 0.91, 0.03], [0.55, 0.43, 0.71]],
dtype='float32',
stop_gradient=False,
)
label = paddle.to_tensor(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype='float32'
)
fg_num_0 = paddle.full([], 2.0)
fg_num_1 = paddle.full([1], 2.0)
out0 = F.sigmoid_focal_loss(logit, label, normalizer=fg_num_0)
out1 = F.sigmoid_focal_loss(logit, label, normalizer=fg_num_1)
np.testing.assert_array_equal(
out0.numpy(),
out1.numpy(),
)
out0.backward()
self.assertEqual(out0.grad.shape, [1])
self.assertEqual(logit.grad.shape, [2, 3])
def test_allclose(self):
x = paddle.full([], 0.5)
y = paddle.full([], 0.6)
self.assertFalse(paddle.allclose(x, y))
class TestSundryAPIStatic(unittest.TestCase):
def setUp(self):
......
......@@ -178,6 +178,8 @@ binary_api_list = [
paddle.logical_and,
paddle.logical_or,
paddle.logical_xor,
paddle.maximum,
paddle.minimum,
]
binary_int_api_list = [
......@@ -703,6 +705,35 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(x1.grad.numpy(), 0)
self.assertEqual(x2.grad.numpy(), 0)
def test_sigmoid_focal_loss(self):
logit = paddle.to_tensor(
[[0.97, 0.91, 0.03], [0.55, 0.43, 0.71]],
dtype='float32',
stop_gradient=False,
)
label = paddle.to_tensor(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype='float32'
)
fg_num_0 = paddle.full([], 2.0)
fg_num_1 = paddle.full([1], 2.0)
out0 = F.sigmoid_focal_loss(logit, label, normalizer=fg_num_0)
out1 = F.sigmoid_focal_loss(logit, label, normalizer=fg_num_1)
np.testing.assert_array_equal(
out0.numpy(),
out1.numpy(),
)
out0.backward()
self.assertEqual(out0.grad.shape, [1])
self.assertEqual(logit.grad.shape, [2, 3])
def test_allclose(self):
x = paddle.full([], 0.5)
y = paddle.full([], 0.6)
self.assertFalse(paddle.allclose(x, y))
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase):
......
......@@ -2882,8 +2882,8 @@ def sigmoid_focal_loss(
``logit``. The target label whose value should be numbers between 0 and 1.
Available dtype is float32, float64.
normalizer (Tensor, optional): The number normalizes the focal loss. It has to be
a 1-D Tensor whose shape is `[1, ]`. The data type is float32, float64.
For object detection task, it is the number of positive samples.
a 1-D Tensor with shape `[1, ]` or 0-D Tensor with shape `[]`. The data type
is float32, float64. For object detection task, it is the number of positive samples.
If set to None, the focal loss will not be normalized. Default is None.
alpha(int|float, optional): Hyper-parameter to balance the positive and negative example,
it should be between 0 and 1. Default value is set to 0.25.
......@@ -2934,7 +2934,7 @@ def sigmoid_focal_loss(
normalizer_dims = len(normalizer_shape)
if normalizer_dims > 1:
raise ValueError(
"Expected one dimension of normalizer in sigmoid_focal_loss but got {}.".format(
"Expected zero or one dimension of normalizer in sigmoid_focal_loss but got {}.".format(
normalizer_dims
)
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册