未验证 提交 dc85b393 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[Zero-Dim] support input 0D Tensor for distribution transform api (#47677)

* [Zero-Dim] support input 0D Tensor for distribution api

* fix comment
上级 047971f0
...@@ -13,9 +13,7 @@ ...@@ -13,9 +13,7 @@
# limitations under the License. # limitations under the License.
import enum import enum
import functools
import math import math
import operator
import typing import typing
import paddle import paddle
...@@ -401,7 +399,7 @@ class AbsTransform(Transform): ...@@ -401,7 +399,7 @@ class AbsTransform(Transform):
return -y, y return -y, y
def _inverse_log_det_jacobian(self, y): def _inverse_log_det_jacobian(self, y):
zero = paddle.zeros([1], dtype=y.dtype) zero = paddle.zeros([], dtype=y.dtype)
return zero, zero return zero, zero
@property @property
...@@ -872,12 +870,16 @@ class ReshapeTransform(Transform): ...@@ -872,12 +870,16 @@ class ReshapeTransform(Transform):
f"Squence[int], but got 'in_event_shape': {in_event_shape}, " f"Squence[int], but got 'in_event_shape': {in_event_shape}, "
f"'out_event_shape': {out_event_shape}" f"'out_event_shape': {out_event_shape}"
) )
if functools.reduce(operator.mul, in_event_shape) != functools.reduce( in_size = 1
operator.mul, out_event_shape for e in in_event_shape:
): in_size *= e
out_size = 1
for e in out_event_shape:
out_size *= e
if in_size != out_size:
raise ValueError( raise ValueError(
f"The numel of 'in_event_shape' should be 'out_event_shape', " f"The numel of 'in_event_shape' should be 'out_event_shape', "
f"but got {functools.reduce(operator.mul, in_event_shape)}!={functools.reduce(operator.mul, out_event_shape)}" f"but got {in_size}!={out_size}"
) )
self._in_event_shape = tuple(in_event_shape) self._in_event_shape = tuple(in_event_shape)
...@@ -917,7 +919,9 @@ class ReshapeTransform(Transform): ...@@ -917,7 +919,9 @@ class ReshapeTransform(Transform):
raise ValueError( raise ValueError(
f"Expected length of 'shape' is not less than {len(self._in_event_shape)}, but got {len(shape)}" f"Expected length of 'shape' is not less than {len(self._in_event_shape)}, but got {len(shape)}"
) )
if shape[-len(self._in_event_shape) :] != self._in_event_shape: if tuple(shape[-len(self._in_event_shape) :]) != tuple(
self._in_event_shape
):
raise ValueError( raise ValueError(
f"Event shape mismatch, expected: {self._in_event_shape}, but got {shape[-len(self._in_event_shape):]}" f"Event shape mismatch, expected: {self._in_event_shape}, but got {shape[-len(self._in_event_shape):]}"
) )
...@@ -930,7 +934,9 @@ class ReshapeTransform(Transform): ...@@ -930,7 +934,9 @@ class ReshapeTransform(Transform):
raise ValueError( raise ValueError(
f"Expected 'shape' length is not less than {len(self._out_event_shape)}, but got {len(shape)}" f"Expected 'shape' length is not less than {len(self._out_event_shape)}, but got {len(shape)}"
) )
if shape[-len(self._out_event_shape) :] != self._out_event_shape: if tuple(shape[-len(self._out_event_shape) :]) != tuple(
self._out_event_shape
):
raise ValueError( raise ValueError(
f"Event shape mismatch, expected: {self._out_event_shape}, but got {shape[-len(self._out_event_shape):]}" f"Event shape mismatch, expected: {self._out_event_shape}, but got {shape[-len(self._out_event_shape):]}"
) )
...@@ -939,7 +945,7 @@ class ReshapeTransform(Transform): ...@@ -939,7 +945,7 @@ class ReshapeTransform(Transform):
) )
def _forward_log_det_jacobian(self, x): def _forward_log_det_jacobian(self, x):
# paddle.zeros not support zero dimension Tensor. # TODO(zhouwei): should not set shape to [1], which is []
shape = x.shape[: x.dim() - len(self._in_event_shape)] or [1] shape = x.shape[: x.dim() - len(self._in_event_shape)] or [1]
return paddle.zeros(shape, dtype=x.dtype) return paddle.zeros(shape, dtype=x.dtype)
......
...@@ -103,7 +103,7 @@ def parameterize_func( ...@@ -103,7 +103,7 @@ def parameterize_func(
frame_locals[name].__doc__ = doc_func(f, num, p) frame_locals[name].__doc__ = doc_func(f, num, p)
# Delete original patches to prevent new function from evaluating # Delete original patches to prevent new function from evaluating
# original patching object as well as re-constructed patches. # original patching object as well as re-constrfucted patches.
delete_patches_if_need(f) delete_patches_if_need(f)
f.__test__ = False f.__test__ = False
......
...@@ -191,6 +191,17 @@ class TestAbsTransform(unittest.TestCase): ...@@ -191,6 +191,17 @@ class TestAbsTransform(unittest.TestCase):
def test_inverse_shape(self, shape, expected_shape): def test_inverse_shape(self, shape, expected_shape):
self.assertEqual(self._t.forward_shape(shape), expected_shape) self.assertEqual(self._t.forward_shape(shape), expected_shape)
@param.param_func([(np.array(1.0), np.array(1.0))])
def test_zerodim(self, input, expected):
x = paddle.to_tensor(input).astype('float32')
self.assertEqual(self._t.forward(x).shape, [])
self.assertEqual(self._t.inverse(x)[0].shape, [])
self.assertEqual(self._t.inverse(x)[1].shape, [])
self.assertEqual(self._t.inverse_log_det_jacobian(x)[0].shape, [])
self.assertEqual(self._t.inverse_log_det_jacobian(x)[1].shape, [])
self.assertEqual(self._t.forward_shape(x.shape), [])
self.assertEqual(self._t.inverse_shape(x.shape), [])
@param.place(config.DEVICES) @param.place(config.DEVICES)
@param.param_cls( @param.param_cls(
...@@ -297,6 +308,18 @@ class TestAffineTransform(unittest.TestCase): ...@@ -297,6 +308,18 @@ class TestAffineTransform(unittest.TestCase):
np.broadcast(np.random.random(shape), self.loc, self.scale).shape, np.broadcast(np.random.random(shape), self.loc, self.scale).shape,
) )
@param.param_func([(np.array(1.0), np.array(1.0))])
def test_zerodim(self, input, expected):
affine = transform.AffineTransform(paddle.zeros([]), paddle.ones([]))
x = paddle.to_tensor(input).astype('float32')
self.assertEqual(affine.forward(x).shape, [])
self.assertEqual(affine.inverse(x).shape, [])
self.assertEqual(affine.forward_log_det_jacobian(x).shape, [])
self.assertEqual(affine.inverse_log_det_jacobian(x).shape, [])
self.assertEqual(affine.forward_shape(x.shape), ())
self.assertEqual(affine.inverse_shape(x.shape), ())
@param.place(config.DEVICES) @param.place(config.DEVICES)
class TestExpTransform(unittest.TestCase): class TestExpTransform(unittest.TestCase):
...@@ -395,6 +418,16 @@ class TestExpTransform(unittest.TestCase): ...@@ -395,6 +418,16 @@ class TestExpTransform(unittest.TestCase):
def test_inverse_shape(self, shape, expected_shape): def test_inverse_shape(self, shape, expected_shape):
self.assertEqual(self._t.forward_shape(shape), expected_shape) self.assertEqual(self._t.forward_shape(shape), expected_shape)
@param.param_func([(np.array(1.0), np.array(1.0))])
def test_zerodim(self, input, expected):
x = paddle.to_tensor(input).astype('float32')
self.assertEqual(self._t.forward(x).shape, [])
self.assertEqual(self._t.inverse(x).shape, [])
self.assertEqual(self._t.forward_log_det_jacobian(x).shape, [])
self.assertEqual(self._t.inverse_log_det_jacobian(x).shape, [])
self.assertEqual(self._t.forward_shape(x.shape), [])
self.assertEqual(self._t.inverse_shape(x.shape), [])
@param.place(config.DEVICES) @param.place(config.DEVICES)
class TestChainTransform(unittest.TestCase): class TestChainTransform(unittest.TestCase):
...@@ -785,6 +818,18 @@ class TestPowerTransform(unittest.TestCase): ...@@ -785,6 +818,18 @@ class TestPowerTransform(unittest.TestCase):
def test_inverse_shape(self, shape, expected_shape): def test_inverse_shape(self, shape, expected_shape):
self.assertEqual(self._t.forward_shape(shape), expected_shape) self.assertEqual(self._t.forward_shape(shape), expected_shape)
@param.param_func([(np.array(2.0), np.array(1.0))])
def test_zerodim(self, input, expected):
power = transform.PowerTransform(paddle.full([], 2.0))
x = paddle.to_tensor(input).astype('float32')
self.assertEqual(power.forward(x).shape, [])
self.assertEqual(power.inverse(x).shape, [])
self.assertEqual(power.forward_log_det_jacobian(x).shape, [])
self.assertEqual(power.inverse_log_det_jacobian(x).shape, [])
self.assertEqual(power.forward_shape(x.shape), ())
self.assertEqual(power.inverse_shape(x.shape), ())
@param.place(config.DEVICES) @param.place(config.DEVICES)
class TestTanhTransform(unittest.TestCase): class TestTanhTransform(unittest.TestCase):
...@@ -892,6 +937,16 @@ class TestTanhTransform(unittest.TestCase): ...@@ -892,6 +937,16 @@ class TestTanhTransform(unittest.TestCase):
def test_inverse_shape(self, shape, expected_shape): def test_inverse_shape(self, shape, expected_shape):
self.assertEqual(self._t.forward_shape(shape), expected_shape) self.assertEqual(self._t.forward_shape(shape), expected_shape)
@param.param_func([(np.array(1.0), np.array(1.0))])
def test_zerodim(self, input, expected):
x = paddle.to_tensor(input).astype('float32')
self.assertEqual(self._t.forward(x).shape, [])
self.assertEqual(self._t.inverse(x).shape, [])
self.assertEqual(self._t.forward_log_det_jacobian(x).shape, [])
self.assertEqual(self._t.inverse_log_det_jacobian(x).shape, [])
self.assertEqual(self._t.forward_shape(x.shape), [])
self.assertEqual(self._t.inverse_shape(x.shape), [])
@param.place(config.DEVICES) @param.place(config.DEVICES)
@param.param_cls( @param.param_cls(
...@@ -965,6 +1020,20 @@ class TestReshapeTransform(unittest.TestCase): ...@@ -965,6 +1020,20 @@ class TestReshapeTransform(unittest.TestCase):
with self.assertRaises(exc): with self.assertRaises(exc):
self._t.inverse_shape(shape) self._t.inverse_shape(shape)
@param.param_func([(np.array(2.0), np.array(1.0))])
def test_zerodim(self, input, expected):
reshape = transform.ReshapeTransform((), (1, 1))
x = paddle.to_tensor(input).astype('float32')
out = reshape.forward(x)
self.assertEqual(out.shape, [1, 1])
self.assertEqual(reshape.inverse(out).shape, [])
# self.assertEqual(reshape.forward_log_det_jacobian(x).shape, [])
# self.assertEqual(reshape.inverse_log_det_jacobian(out).shape, [])
self.assertEqual(reshape.forward_shape(x.shape), (1, 1))
self.assertEqual(reshape.inverse_shape(out.shape), ())
def _np_softplus(x, beta=1.0, threshold=20.0): def _np_softplus(x, beta=1.0, threshold=20.0):
if np.any(beta * x > threshold): if np.any(beta * x > threshold):
...@@ -1031,6 +1100,16 @@ class TestSigmoidTransform(unittest.TestCase): ...@@ -1031,6 +1100,16 @@ class TestSigmoidTransform(unittest.TestCase):
def test_inverse_shape(self, shape, expected_shape): def test_inverse_shape(self, shape, expected_shape):
self.assertEqual(self._t.forward_shape(shape), expected_shape) self.assertEqual(self._t.forward_shape(shape), expected_shape)
@param.param_func([(np.array(1.0), np.array(1.0))])
def test_zerodim(self, input, expected):
x = paddle.to_tensor(input).astype('float32')
self.assertEqual(self._t.forward(x).shape, [])
self.assertEqual(self._t.inverse(x).shape, [])
self.assertEqual(self._t.forward_log_det_jacobian(x).shape, [])
self.assertEqual(self._t.inverse_log_det_jacobian(x).shape, [])
self.assertEqual(self._t.forward_shape(x.shape), [])
self.assertEqual(self._t.inverse_shape(x.shape), [])
class TestSoftmaxTransform(unittest.TestCase): class TestSoftmaxTransform(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册