未验证 提交 2bd0a946 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[Zero-Dim] Support broadcast_tensors input 0D and distribution API output 0D (#51721)

上级 0f9ec013
......@@ -771,12 +771,6 @@ void BroadcastTensorsInferMeta(const std::vector<const MetaTensor*>& x,
target_rank = std::max(target_rank, input_ddim.size());
}
PADDLE_ENFORCE_GT(target_rank,
0,
errors::InvalidArgument("BroadcastTensorsOp requires at "
"least one input tensor to have "
"rank greater than zero"));
std::vector<int64_t> target_dims(target_rank, 0);
// 2. Output dim(axis=x) = max(Inputs dim(axis=x))
for (int index = 0; index < target_rank; index++) {
......
......@@ -37,6 +37,7 @@ struct EigenBroadcast<Eigen::DefaultDevice, T, Rank> {
OutType out,
InType in,
const Array& bcast) {
// Eigen::TensorMap.broadcast not support 0D
out.device(dev) = in.broadcast(bcast);
}
......
......@@ -18,6 +18,7 @@
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/broadcast_tensors_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
......@@ -48,9 +49,9 @@ void ApplyBroadcast(const Context& ctx,
// expanded dims: "new_input_dims_vec"
Eigen::DSizes<Eigen::DenseIndex, OutRank> bcast_dims;
std::vector<int64_t> new_input_dims_vec(out_rank);
for (int j = 0; j < out_rank; j++) {
int out_axis = out_rank - j - 1;
int in_axis = in_rank - j - 1;
for (int i = 0; i < out_rank; i++) {
int in_axis = in_rank - i - 1;
int out_axis = out_rank - i - 1;
bcast_dims[out_axis] = output_dims[out_axis];
new_input_dims_vec[out_axis] = 1;
......@@ -101,12 +102,18 @@ void BroadcastTensorsKernel(const Context& ctx,
for (size_t i = 0; i < num_ins; i++) {
int out_rank = out_tensors[i]->dims().size();
switch (out_rank) {
SWITCH_OUT_RANK_CASE(1)
SWITCH_OUT_RANK_CASE(2)
SWITCH_OUT_RANK_CASE(3)
SWITCH_OUT_RANK_CASE(4)
SWITCH_OUT_RANK_CASE(5)
SWITCH_OUT_RANK_CASE(6)
case 0: {
const DenseTensor* src = in_tensors[i];
DenseTensor* dst = out_tensors[i];
phi::Copy(ctx, *src, src->place(), false, dst);
break;
}
SWITCH_OUT_RANK_CASE(1)
SWITCH_OUT_RANK_CASE(2)
SWITCH_OUT_RANK_CASE(3)
SWITCH_OUT_RANK_CASE(4)
SWITCH_OUT_RANK_CASE(5)
SWITCH_OUT_RANK_CASE(6)
default: {
PADDLE_THROW(phi::errors::InvalidArgument(
"Target tensor rank out of range"
......
......@@ -60,13 +60,13 @@ class Beta(exponential_family.ExponentialFamily):
# scale input
beta = paddle.distribution.Beta(alpha=0.5, beta=0.5)
print(beta.mean)
# Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# Tensor(shape=[], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [0.50000000])
print(beta.variance)
# Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# Tensor(shape=[], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [0.12500000])
print(beta.entropy())
# Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# Tensor(shape=[], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [0.12500000])
# tensor input with broadcast
......@@ -84,10 +84,10 @@ class Beta(exponential_family.ExponentialFamily):
def __init__(self, alpha, beta):
if isinstance(alpha, numbers.Real):
alpha = paddle.full(shape=[1], fill_value=alpha)
alpha = paddle.full(shape=[], fill_value=alpha)
if isinstance(beta, numbers.Real):
beta = paddle.full(shape=[1], fill_value=beta)
beta = paddle.full(shape=[], fill_value=beta)
self.alpha, self.beta = paddle.broadcast_tensors([alpha, beta])
......
......@@ -62,10 +62,10 @@ class Dirichlet(exponential_family.ExponentialFamily):
dirichlet = paddle.distribution.Dirichlet(paddle.to_tensor([1., 2., 3.]))
print(dirichlet.entropy())
# Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# Tensor(shape=[], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [-1.24434423])
print(dirichlet.prob(paddle.to_tensor([.3, .5, .6])))
# Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# Tensor(shape=[], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [10.80000114])
"""
......
......@@ -134,7 +134,11 @@ class Distribution:
Returns:
Tensor: generated sample data shape
"""
return sample_shape + self._batch_shape + self._event_shape
return (
tuple(sample_shape)
+ tuple(self._batch_shape)
+ tuple(self._event_shape)
)
def _validate_args(self, *args):
"""
......@@ -173,11 +177,11 @@ class Distribution:
tmp = 0.0
for arg in args:
if isinstance(arg, float):
arg = [arg]
if not isinstance(arg, (list, tuple, np.ndarray, tensor.Variable)):
if not isinstance(
arg, (float, list, tuple, np.ndarray, tensor.Variable)
):
raise TypeError(
"Type of input args must be float, list, numpy.ndarray or Tensor, but received type {}".format(
"Type of input args must be float, list, tuple, numpy.ndarray or Tensor, but received type {}".format(
type(arg)
)
)
......
......@@ -61,7 +61,7 @@ class Gumbel(TransformedDistribution):
dist.cdf(value)
# Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [0.54523915])
dist.entropy()
# Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [1.57721567])
# Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, [1.57721567])
dist.rsample([2])
# Tensor(shape=[2, 1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [[0.80463481], [0.91893655]])
......
......@@ -56,7 +56,7 @@ def kl_divergence(p, q):
q = paddle.distribution.Beta(alpha=0.3, beta=0.7)
print(paddle.distribution.kl_divergence(p, q))
# Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# Tensor(shape=[], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [0.21193528])
"""
......
......@@ -48,7 +48,7 @@ class Laplace(distribution.Distribution):
m = paddle.distribution.Laplace(paddle.to_tensor([0.0]), paddle.to_tensor([1.0]))
m.sample() # Laplace distributed with loc=0, scale=1
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
# [3.68546247])
"""
......@@ -209,7 +209,7 @@ class Laplace(distribution.Distribution):
m = paddle.distribution.Laplace(paddle.to_tensor([0.0]), paddle.to_tensor([1.0]))
m.entropy()
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
# [1.69314718])
"""
return 1 + paddle.log(2 * self.scale)
......@@ -304,14 +304,10 @@ class Laplace(distribution.Distribution):
m = paddle.distribution.Laplace(paddle.to_tensor([0.0]), paddle.to_tensor([1.0]))
m.sample() # Laplace distributed with loc=0, scale=1
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
# [3.68546247])
"""
if not isinstance(shape, tuple):
raise TypeError(
f'Expected shape should be tuple[int], but got {type(shape)}'
)
shape = shape if isinstance(shape, tuple) else tuple(shape)
with paddle.no_grad():
return self.rsample(shape)
......@@ -336,22 +332,16 @@ class Laplace(distribution.Distribution):
"""
eps = self._get_eps()
shape = self._extend_shape(shape) or (1,)
shape = self._extend_shape(shape)
uniform = paddle.uniform(
shape=shape,
min=float(np.nextafter(-1, 1)) + eps / 2,
max=1.0 - eps / 2,
dtype=self.loc.dtype,
)
if len(self.scale.shape) == 0 and len(self.loc.shape) == 0:
loc, scale, uniform = paddle.broadcast_tensors(
[self.loc, self.scale, uniform]
)
else:
loc, scale = self.loc, self.scale
return loc - scale * uniform.sign() * paddle.log1p(-uniform.abs())
return self.loc - self.scale * uniform.sign() * paddle.log1p(
-uniform.abs()
)
def _get_eps(self):
"""
......@@ -410,7 +400,7 @@ class Laplace(distribution.Distribution):
m1 = paddle.distribution.Laplace(paddle.to_tensor([0.0]), paddle.to_tensor([1.0]))
m2 = paddle.distribution.Laplace(paddle.to_tensor([1.0]), paddle.to_tensor([0.5]))
m1.kl_divergence(m2)
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
# [1.04261160])
"""
......
......@@ -72,13 +72,13 @@ class LogNormal(TransformedDistribution):
sample = lognormal_a.sample((2, ))
# a random tensor created by lognormal distribution with shape: [2, 1]
entropy = lognormal_a.entropy()
# [1.4189385] with shape: [1]
# [1.4189385] with shape: []
lp = lognormal_a.log_prob(value_tensor)
# [-0.72069150] with shape: [1]
p = lognormal_a.probs(value_tensor)
# [0.48641577] with shape: [1]
kl = lognormal_a.kl_divergence(lognormal_b)
# [0.34939718] with shape: [1]
# [0.34939718] with shape: []
"""
def __init__(self, loc, scale):
......
......@@ -166,7 +166,7 @@ class Multinomial(distribution.Distribution):
Tensor: entropy value
"""
n = paddle.full(
shape=[1], fill_value=self.total_count, dtype=self.probs.dtype
shape=[], fill_value=self.total_count, dtype=self.probs.dtype
)
support = paddle.arange(
self.total_count + 1, dtype=self.probs.dtype
......
......@@ -77,13 +77,13 @@ class Normal(distribution.Distribution):
sample = normal_a.sample([2])
# a random tensor created by normal distribution with shape: [2, 1]
entropy = normal_a.entropy()
# [1.4189385] with shape: [1]
# [1.4189385] with shape: []
lp = normal_a.log_prob(value_tensor)
# [-1.2389386] with shape: [1]
p = normal_a.probs(value_tensor)
# [0.28969154] with shape: [1]
kl = normal_a.kl_divergence(normal_b)
# [0.34939718] with shape: [1]
# [0.34939718] with shape: []
"""
def __init__(self, loc, scale, name=None):
......@@ -101,7 +101,6 @@ class Normal(distribution.Distribution):
'Normal',
)
self.batch_size_unknown = False
self.all_arg_is_float = False
self.name = name if name is not None else 'Normal'
self.dtype = 'float32'
......@@ -112,7 +111,6 @@ class Normal(distribution.Distribution):
scale = float(scale)
if self._validate_args(loc, scale):
self.batch_size_unknown = True
self.loc = loc
self.scale = scale
self.dtype = convert_dtype(loc.dtype)
......@@ -174,8 +172,7 @@ class Normal(distribution.Distribution):
shape = list(shape)
batch_shape = list((self.loc + self.scale).shape)
name = self.name + '_sample'
if self.batch_size_unknown:
if -1 in batch_shape:
output_shape = shape + batch_shape
zero_tmp = tensor.fill_constant_batch_size_like(
self.loc + self.scale, batch_shape + shape, self.dtype, 0.0
......@@ -236,9 +233,12 @@ class Normal(distribution.Distribution):
"""
name = self.name + '_entropy'
batch_shape = list((self.loc + self.scale).shape)
zero_tmp = tensor.fill_constant_batch_size_like(
self.loc + self.scale, batch_shape, self.dtype, 0.0
)
if -1 in batch_shape:
zero_tmp = tensor.fill_constant_batch_size_like(
self.loc + self.scale, batch_shape, self.dtype, 0.0
)
else:
zero_tmp = paddle.full(batch_shape, 0.0, self.dtype)
return paddle.add(
0.5 + zero_tmp,
0.5 * math.log(2 * math.pi) + paddle.log(self.scale + zero_tmp),
......
......@@ -368,7 +368,7 @@ class AbsTransform(Transform):
# Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [1., 0., 1.])
print(abs.inverse(paddle.to_tensor(1.)))
print(abs.inverse(paddle.to_tensor([1.])))
# (Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [-1.]), Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [1.]))
......@@ -380,7 +380,7 @@ class AbsTransform(Transform):
# 0.))
#Special case handling of 0.
print(abs.inverse(paddle.to_tensor(0.)))
print(abs.inverse(paddle.to_tensor([0.])))
# (Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [0.]), Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [0.]))
......
......@@ -84,7 +84,7 @@ class Uniform(distribution.Distribution):
sample = uniform.sample([2])
# a random tensor created by uniform distribution with shape: [2, 1]
entropy = uniform.entropy()
# [0.6931472] with shape: [1]
# [0.6931472] with shape: []
lp = uniform.log_prob(value_tensor)
# [-0.6931472] with shape: [1]
p = uniform.probs(value_tensor)
......@@ -117,7 +117,6 @@ class Uniform(distribution.Distribution):
high = float(high)
if self._validate_args(low, high):
self.batch_size_unknown = True
self.low = low
self.high = high
self.dtype = convert_dtype(low.dtype)
......@@ -159,7 +158,7 @@ class Uniform(distribution.Distribution):
name = self.name + '_sample'
batch_shape = list((self.low + self.high).shape)
if self.batch_size_unknown:
if -1 in batch_shape:
output_shape = shape + batch_shape
zero_tmp = tensor.fill_constant_batch_size_like(
self.low + self.high, batch_shape + shape, self.dtype, 0.0
......
......@@ -69,7 +69,7 @@ class TestIndependent(unittest.TestCase):
def test_rsample(self):
shape = [5, 10, 8]
expected_shape = (5, 10, 8, 1)
expected_shape = (5, 10, 8)
data = self._t.rsample(shape)
self.assertEqual(tuple(data.shape), expected_shape)
self.assertEqual(data.dtype, self.base.loc.dtype)
......
......@@ -768,6 +768,44 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out3.grad.shape, [3, 3])
np.testing.assert_allclose(out3.grad, 1.0)
def test_broadcast_tensors(self):
# 1) x is 0D, y is 0D
x1 = paddle.full([], 2.0)
x1.stop_gradient = False
x2 = paddle.full([], 2.0)
x2.stop_gradient = False
out1, out2 = paddle.broadcast_tensors([x1, x2])
# backward has bug now
# out1.backward()
self.assertEqual(out1.shape, [])
self.assertEqual(out2.shape, [])
# self.assertEqual(x1.grad.shape, [])
# 2) x is ND , y is 0D
x1 = paddle.full([2, 3], 2.0)
x1.stop_gradient = False
x2 = paddle.full([], 2.0)
x2.stop_gradient = False
out1, out2 = paddle.broadcast_tensors([x1, x2])
# out1.backward()
self.assertEqual(out1.shape, [2, 3])
self.assertEqual(out2.shape, [2, 3])
# self.assertEqual(x1.grad.shape, [2, 3])
# 3) x is 0D , y is ND
x1 = paddle.full([], 2.0)
x1.stop_gradient = False
x2 = paddle.full([2, 3], 2.0)
x2.stop_gradient = False
out1, out2 = paddle.broadcast_tensors([x1, x2])
# out1.backward()
self.assertEqual(out1.shape, [2, 3])
self.assertEqual(out2.shape, [2, 3])
# self.assertEqual(x1.grad.shape, [2, 3])
def test_broadcast_shape(self):
x = []
y = [3, 5]
......@@ -3540,6 +3578,37 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res[0].shape, (0))
np.testing.assert_array_equal(res[0], np.array([]))
def test_broadcast_tensors(self):
# 1) x is 0D, y is 0D
x1 = paddle.full([], 2.0)
x1.stop_gradient = False
x2 = paddle.full([], 2.0)
x2.stop_gradient = False
out1, out2 = paddle.broadcast_tensors([x1, x2])
self.assertEqual(out1.shape, ())
self.assertEqual(out2.shape, ())
# 2) x is ND , y is 0D
x1 = paddle.full([2, 3], 2.0)
x1.stop_gradient = False
x2 = paddle.full([], 2.0)
x2.stop_gradient = False
out1, out2 = paddle.broadcast_tensors([x1, x2])
self.assertEqual(out1.shape, (2, 3))
self.assertEqual(out2.shape, (2, 3))
# 3) x is 0D , y is ND
x1 = paddle.full([], 2.0)
x1.stop_gradient = False
x2 = paddle.full([2, 3], 2.0)
x2.stop_gradient = False
out1, out2 = paddle.broadcast_tensors([x1, x2])
self.assertEqual(out1.shape, (2, 3))
self.assertEqual(out2.shape, (2, 3))
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase):
......@@ -4114,5 +4183,129 @@ class TestAsComplex(unittest.TestCase):
paddle.disable_static()
class TestDistribution(unittest.TestCase):
def setUp(self):
self.x = paddle.full([], 2.0)
def test_Categorical(self):
logits = paddle.rand([6])
d = paddle.distribution.Categorical(logits)
self.assertEqual(d.sample([]).shape, [])
self.assertEqual(d.probs(paddle.full([], 2, dtype='int64')).shape, [])
self.assertEqual(
d.log_prob(paddle.full([], 2, dtype='int64')).shape, []
)
# because use paddle.sum
# self.assertEqual(d.entropy().shape, [])
def test_Normal(self):
normal = paddle.distribution.Normal(0.0, 3.0)
self.assertEqual(normal.sample([]).shape, [])
self.assertEqual(normal.rsample([]).shape, [])
self.assertEqual(normal.mean.shape, [])
self.assertEqual(normal.variance.shape, [])
self.assertEqual(normal.probs(self.x).shape, [])
self.assertEqual(normal.log_prob(self.x).shape, [])
self.assertEqual(normal.entropy().shape, [])
normal = paddle.distribution.Normal(
paddle.full([], 0.0), paddle.full([], 3.0)
)
self.assertEqual(normal.sample([]).shape, [])
self.assertEqual(normal.rsample([]).shape, [])
self.assertEqual(normal.mean.shape, [])
self.assertEqual(normal.variance.shape, [])
self.assertEqual(normal.probs(self.x).shape, [])
self.assertEqual(normal.log_prob(self.x).shape, [])
self.assertEqual(normal.entropy().shape, [])
def test_Uniform(self):
uniform = paddle.distribution.Uniform(0.0, 1.0)
self.assertEqual(uniform.sample([]).shape, [])
self.assertEqual(uniform.probs(self.x).shape, [])
self.assertEqual(uniform.log_prob(self.x).shape, [])
self.assertEqual(uniform.entropy().shape, [])
uniform = paddle.distribution.Uniform(
paddle.full([], 0.0), paddle.full([], 1.0)
)
self.assertEqual(uniform.sample([]).shape, [])
self.assertEqual(uniform.probs(self.x).shape, [])
self.assertEqual(uniform.log_prob(self.x).shape, [])
self.assertEqual(uniform.entropy().shape, [])
def test_Beta(self):
beta = paddle.distribution.Beta(alpha=0.5, beta=0.5)
self.assertEqual(beta.sample([]).shape, [])
self.assertEqual(beta.mean.shape, [])
self.assertEqual(beta.variance.shape, [])
# because use paddle.sum
# self.assertEqual(beta.prob(self.x).shape, [])
# self.assertEqual(beta.log_prob(self.x).shape, [])
# self.assertEqual(beta.entropy().shape, [])
def test_kl_divergence(self):
p = paddle.distribution.Beta(alpha=0.5, beta=0.5)
q = paddle.distribution.Beta(alpha=0.2, beta=1.0)
kl = paddle.distribution.kl_divergence(p, q)
self.assertEqual(kl.shape, [])
def test_TransformedDistribution(self):
d = paddle.distribution.TransformedDistribution(
paddle.distribution.Normal(0.0, 1.0),
[
paddle.distribution.AffineTransform(
paddle.full([], 1.0), paddle.full([], 2.0)
)
],
)
self.assertEqual(d.sample([]).shape, [])
self.assertEqual(d.rsample([]).shape, [])
self.assertEqual(d.prob(self.x).shape, [])
self.assertEqual(d.log_prob(self.x).shape, [])
def test_Laplace(self):
d = paddle.distribution.Laplace(0.0, 1.0)
self.assertEqual(d.sample([]).shape, [])
self.assertEqual(d.rsample([]).shape, [])
self.assertEqual(d.mean.shape, [])
self.assertEqual(d.stddev.shape, [])
self.assertEqual(d.variance.shape, [])
self.assertEqual(d.prob(self.x).shape, [])
self.assertEqual(d.log_prob(self.x).shape, [])
self.assertEqual(d.cdf(self.x).shape, [])
self.assertEqual(d.icdf(self.x).shape, [])
self.assertEqual(d.entropy().shape, [])
def test_LogNormal(self):
d = paddle.distribution.LogNormal(0.0, 1.0)
self.assertEqual(d.sample([]).shape, [])
self.assertEqual(d.mean.shape, [])
self.assertEqual(d.variance.shape, [])
self.assertEqual(d.entropy().shape, [])
self.assertEqual(d.probs(self.x).shape, [])
def test_Gumbel(self):
d = paddle.distribution.Gumbel(0.0, 1.0)
self.assertEqual(d.sample([]).shape, [])
self.assertEqual(d.rsample([]).shape, [])
self.assertEqual(d.mean.shape, [])
self.assertEqual(d.variance.shape, [])
self.assertEqual(d.stddev.shape, [])
self.assertEqual(d.prob(self.x).shape, [])
self.assertEqual(d.log_prob(self.x).shape, [])
self.assertEqual(d.cdf(self.x).shape, [])
self.assertEqual(d.entropy().shape, [])
def test_Multinomial(self):
d = paddle.distribution.Multinomial(
10, paddle.to_tensor([0.2, 0.3, 0.5])
)
# because use paddle.sum
# self.assertEqual(d.prob(self.x).shape, [])
# self.assertEqual(d.log_prob(self.x).shape, [])
# self.assertEqual(d.entropy().shape, [])
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册