Unverified commit 6735a37a, authored by Xiaoxu Chen, committed by GitHub

Add probability distribution transformation APIs (#40536)

* add random variable transformations API for paddle's distribution package

* add TransformedDistribution API for paddle's probability distribution package

* add random variable transformation unittests for static graph

* replace math.prod, which is not supported on python3.7, with functools.reduce

* add Independent and TransformedDistribution distribution

* add unittests for constraint

* fix typo and AffineTransform sample code error

* add mean,variance,rsample abstract method for Distribution
Parent 7dfd3846
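Taken together, the features above compose as follows. A minimal sketch pieced together from the docstrings added in this diff (printed shapes are illustrative; device placement depends on the build):

import paddle

# Reinterpret the rightmost batch axis of a Beta as an event axis.
beta = paddle.distribution.Beta(paddle.rand([3, 2]), paddle.rand([3, 2]))
wrapped = paddle.distribution.Independent(beta, 1)
print(wrapped.batch_shape, wrapped.event_shape)
# e.g. (3,) (2,)

# Push a standard Normal through an affine transform: y = 1. + 2. * x.
d = paddle.distribution.TransformedDistribution(
    paddle.distribution.Normal(0., 1.),
    [paddle.distribution.AffineTransform(paddle.to_tensor(1.),
                                         paddle.to_tensor(2.))])
print(d.sample([10]).shape)
# [10]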
@@ -12,15 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .beta import Beta
-from .categorical import Categorical
-from .dirichlet import Dirichlet
-from .distribution import Distribution
-from .exponential_family import ExponentialFamily
-from .kl import kl_divergence, register_kl
-from .multinomial import Multinomial
-from .normal import Normal
-from .uniform import Uniform
+from paddle.distribution import transform
+from paddle.distribution.beta import Beta
+from paddle.distribution.categorical import Categorical
+from paddle.distribution.dirichlet import Dirichlet
+from paddle.distribution.distribution import Distribution
+from paddle.distribution.exponential_family import ExponentialFamily
+from paddle.distribution.independent import Independent
+from paddle.distribution.kl import kl_divergence, register_kl
+from paddle.distribution.multinomial import Multinomial
+from paddle.distribution.normal import Normal
+from paddle.distribution.transform import *  # noqa: F403
+from paddle.distribution.transformed_distribution import \
+    TransformedDistribution
+from paddle.distribution.uniform import Uniform

 __all__ = [  # noqa
     'Beta',
@@ -33,4 +38,8 @@ __all__ = [  # noqa
     'Uniform',
     'kl_divergence',
     'register_kl',
+    'Independent',
+    'TransformedDistribution'
 ]
+
+__all__.extend(transform.__all__)
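With these re-exports, the transform classes become reachable from the package root. A small sketch, assuming ExpTransform is part of transform.__all__ as the tests later in this diff suggest:

import paddle
from paddle.distribution import ExpTransform

# Both spellings resolve to the same re-exported class.
print(paddle.distribution.ExpTransform is ExpTransform)  # True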
@@ -14,12 +14,10 @@
 import numbers

 import paddle
-
-from .dirichlet import Dirichlet
-from .exponential_family import ExponentialFamily
+from paddle.distribution import dirichlet, exponential_family


-class Beta(ExponentialFamily):
+class Beta(exponential_family.ExponentialFamily):
     r"""
     Beta distribution parameterized by alpha and beta.
@@ -93,7 +91,8 @@ class Beta(ExponentialFamily):
         self.alpha, self.beta = paddle.broadcast_tensors([alpha, beta])
-        self._dirichlet = Dirichlet(paddle.stack([self.alpha, self.beta], -1))
+        self._dirichlet = dirichlet.Dirichlet(
+            paddle.stack([self.alpha, self.beta], -1))

         super(Beta, self).__init__(self._dirichlet._batch_shape)
......
@@ -18,18 +18,18 @@ import warnings
 import numpy as np
 import paddle
 from paddle import _C_ops
-
-from ..fluid import core
-from ..fluid.data_feeder import (check_dtype, check_type,
-                                 check_variable_and_dtype, convert_dtype)
-from ..fluid.framework import _non_static_mode
-from ..fluid.layers import (control_flow, elementwise_add, elementwise_div,
-                            elementwise_mul, elementwise_sub, nn, ops, tensor)
-from ..tensor import arange, concat, gather_nd, multinomial
-from .distribution import Distribution
+from paddle.distribution import distribution
+from paddle.fluid import core
+from paddle.fluid.data_feeder import (check_dtype, check_type,
+                                      check_variable_and_dtype, convert_dtype)
+from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
+from paddle.fluid.layers import (control_flow, elementwise_add, elementwise_div,
+                                 elementwise_mul, elementwise_sub, nn, ops,
+                                 tensor)
+from paddle.tensor import arange, concat, gather_nd, multinomial


-class Categorical(Distribution):
+class Categorical(distribution.Distribution):
     r"""
     Categorical distribution is a discrete probability distribution that
     describes the possible results of a random variable that can take on
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle


class Constraint(object):
    """Constraint condition for random variable."""

    def __call__(self, value):
        raise NotImplementedError


class Real(Constraint):
    def __call__(self, value):
        # NaN is the only value that does not compare equal to itself.
        return value == value


class Range(Constraint):
    def __init__(self, lower, upper):
        self._lower = lower
        self._upper = upper
        super(Range, self).__init__()

    def __call__(self, value):
        return self._lower <= value <= self._upper


class Positive(Constraint):
    def __call__(self, value):
        return value >= 0.


class Simplex(Constraint):
    def __call__(self, value):
        # ``paddle.logical_and`` rather than the Python ``and`` keyword,
        # which cannot combine boolean tensors element-wise.
        return paddle.logical_and(
            paddle.all(value >= 0, axis=-1), (value.sum(-1) - 1).abs() < 1e-6)


real = Real()
positive = Positive()
simplex = Simplex()
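A quick sketch of how the constraint singletons behave; the results are boolean tensors, and the commented values are what they evaluate to:

import paddle
from paddle.distribution import constraint

print(constraint.positive(paddle.to_tensor(2.)))        # evaluates to True
print(constraint.real(paddle.to_tensor(float('nan'))))  # evaluates to False
print(constraint.simplex(paddle.to_tensor([0.3, 0.7]))) # evaluates to True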
@@ -13,14 +13,13 @@
 # limitations under the License.

 import paddle
-
-from ..fluid.data_feeder import check_variable_and_dtype
-from ..fluid.framework import _non_static_mode
-from ..fluid.layer_helper import LayerHelper
-from .exponential_family import ExponentialFamily
+from paddle.distribution import exponential_family
+from paddle.fluid.data_feeder import check_variable_and_dtype
+from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
+from paddle.fluid.layer_helper import LayerHelper


-class Dirichlet(ExponentialFamily):
+class Dirichlet(exponential_family.ExponentialFamily):
     r"""
     Dirichlet distribution with parameter "concentration".
......
@@ -27,14 +27,14 @@ import warnings
 import numpy as np
 import paddle
 from paddle import _C_ops
-
-from ..fluid import core
-from ..fluid.data_feeder import (check_dtype, check_type,
-                                 check_variable_and_dtype, convert_dtype)
-from ..fluid.framework import _non_static_mode
-from ..fluid.layers import (control_flow, elementwise_add, elementwise_div,
-                            elementwise_mul, elementwise_sub, nn, ops, tensor)
-from ..tensor import arange, concat, gather_nd, multinomial
+from paddle.fluid import core
+from paddle.fluid.data_feeder import (check_dtype, check_type,
+                                      check_variable_and_dtype, convert_dtype)
+from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
+from paddle.fluid.layers import (control_flow, elementwise_add, elementwise_div,
+                                 elementwise_mul, elementwise_sub, nn, ops,
+                                 tensor)
+from paddle.tensor import arange, concat, gather_nd, multinomial


 class Distribution(object):
@@ -78,10 +78,24 @@ class Distribution(object):
         """
         return self._event_shape

+    @property
+    def mean(self):
+        """Mean of distribution"""
+        raise NotImplementedError
+
+    @property
+    def variance(self):
+        """Variance of distribution"""
+        raise NotImplementedError
+
     def sample(self, shape=()):
         """Sampling from the distribution."""
         raise NotImplementedError

+    def rsample(self, shape=()):
+        """reparameterized sample"""
+        raise NotImplementedError
+
     def entropy(self):
         """The entropy of the distribution."""
         raise NotImplementedError
@@ -96,7 +110,7 @@ class Distribution(object):
         Args:
             value (Tensor): value which will be evaluated
         """
-        raise NotImplementedError
+        return self.log_prob(value).exp()

     def log_prob(self, value):
         """Log probability density/mass function."""
......
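Since prob now defaults to exponentiating log_prob, any subclass that implements log_prob inherits a consistent prob. A hedged check of that identity using Beta from this package:

import paddle

beta = paddle.distribution.Beta(paddle.to_tensor(0.5), paddle.to_tensor(0.5))
value = paddle.to_tensor(0.2)
# prob(value) and exp(log_prob(value)) agree up to floating-point error.
print(paddle.allclose(beta.prob(value), beta.log_prob(value).exp()))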
@@ -13,12 +13,11 @@
 # limitations under the License.

 import paddle
-
-from ..fluid.framework import _non_static_mode
-from .distribution import Distribution
+from paddle.distribution import distribution
+from paddle.fluid.framework import _non_static_mode, in_dygraph_mode


-class ExponentialFamily(Distribution):
+class ExponentialFamily(distribution.Distribution):
     r"""
     ExponentialFamily is the base class for probability distributions belonging
     to exponential family, whose probability mass/density function has the
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distribution import distribution


class Independent(distribution.Distribution):
    r"""
    Reinterprets some of the batch dimensions of a distribution as event dimensions.

    This is mainly useful for changing the shape of the result of
    :meth:`log_prob`.

    Args:
        base (Distribution): The base distribution.
        reinterpreted_batch_rank (int): The number of batch dimensions to
            reinterpret as event dimensions.

    Examples:

        .. code-block:: python

            import paddle
            from paddle.distribution import independent

            beta = paddle.distribution.Beta(paddle.to_tensor([0.5, 0.5]), paddle.to_tensor([0.5, 0.5]))
            print(beta.batch_shape, beta.event_shape)
            # (2,) ()
            print(beta.log_prob(paddle.to_tensor(0.2)))
            # Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [-0.22843921, -0.22843921])
            reinterpreted_beta = independent.Independent(beta, 1)
            print(reinterpreted_beta.batch_shape, reinterpreted_beta.event_shape)
            # () (2,)
            print(reinterpreted_beta.log_prob(paddle.to_tensor([0.2, 0.2])))
            # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [-0.45687842])
    """

    def __init__(self, base, reinterpreted_batch_rank):
        if not isinstance(base, distribution.Distribution):
            raise TypeError(
                f"Expected type of 'base' is Distribution, but got {type(base)}")
        if not (0 < reinterpreted_batch_rank <= len(base.batch_shape)):
            raise ValueError(
                f"Expected 0 < reinterpreted_batch_rank <= {len(base.batch_shape)}, but got {reinterpreted_batch_rank}"
            )
        self._base = base
        self._reinterpreted_batch_rank = reinterpreted_batch_rank

        shape = base.batch_shape + base.event_shape
        super(Independent, self).__init__(
            batch_shape=shape[:len(base.batch_shape) -
                              reinterpreted_batch_rank],
            event_shape=shape[len(base.batch_shape) -
                              reinterpreted_batch_rank:])

    @property
    def mean(self):
        return self._base.mean

    @property
    def variance(self):
        return self._base.variance

    def sample(self, shape=()):
        return self._base.sample(shape)

    def log_prob(self, value):
        return self._sum_rightmost(
            self._base.log_prob(value), self._reinterpreted_batch_rank)

    def prob(self, value):
        return self.log_prob(value).exp()

    def entropy(self):
        return self._sum_rightmost(self._base.entropy(),
                                   self._reinterpreted_batch_rank)

    def _sum_rightmost(self, value, n):
        return value.sum(list(range(-n, 0))) if n > 0 else value
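The log_prob reduction above is a plain sum over the reinterpreted rightmost axes. The same arithmetic in NumPy, for illustration:

import numpy as np


def sum_rightmost(value, n):
    # Mirror of Independent._sum_rightmost: reduce the last n axes.
    return value.sum(tuple(range(-n, 0))) if n > 0 else value


x = np.ones((3, 4, 5))
print(sum_rightmost(x, 2).shape)  # (3,): the last two axes were summed out
print(sum_rightmost(x, 0).shape)  # (3, 4, 5): unchanged when n == 0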
@@ -15,15 +15,14 @@ import functools
 import warnings

 import paddle
-
-from ..fluid.framework import _non_static_mode
-from .beta import Beta
-from .categorical import Categorical
-from .dirichlet import Dirichlet
-from .distribution import Distribution
-from .exponential_family import ExponentialFamily
-from .normal import Normal
-from .uniform import Uniform
+from paddle.distribution.beta import Beta
+from paddle.distribution.categorical import Categorical
+from paddle.distribution.dirichlet import Dirichlet
+from paddle.distribution.distribution import Distribution
+from paddle.distribution.exponential_family import ExponentialFamily
+from paddle.distribution.normal import Normal
+from paddle.distribution.uniform import Uniform
+from paddle.fluid.framework import _non_static_mode, in_dygraph_mode

 __all__ = ["register_kl", "kl_divergence"]
@@ -207,5 +206,4 @@ def _kl_expfamily_expfamily(p, q):
 def _sum_rightmost(value, n):
-    """Sum elements along rightmost n dim"""
     return value.sum(list(range(-n, 0))) if n > 0 else value
@@ -17,18 +17,17 @@ import warnings
 import numpy as np
 from paddle import _C_ops
-
-from ..fluid import core
-from ..fluid.data_feeder import (check_dtype, check_type,
-                                 check_variable_and_dtype, convert_dtype)
-from ..fluid.framework import _non_static_mode
-from ..fluid.layers import (control_flow, elementwise_add, elementwise_div,
-                            elementwise_mul, elementwise_sub, nn, ops, tensor)
-from ..tensor import arange, concat, gather_nd, multinomial
-from .distribution import Distribution
+from paddle.distribution import distribution
+from paddle.fluid import core
+from paddle.fluid.data_feeder import (check_dtype, check_type,
+                                      check_variable_and_dtype, convert_dtype)
+from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
+from paddle.fluid.layers import (control_flow, elementwise_add, elementwise_div,
+                                 elementwise_mul, elementwise_sub, nn, ops,
+                                 tensor)


-class Normal(Distribution):
+class Normal(distribution.Distribution):
     r"""The Normal distribution with location `loc` and `scale` parameters.

     Mathematical details
@@ -129,6 +128,7 @@ class Normal(Distribution):
         if self.dtype != convert_dtype(self.loc.dtype):
             self.loc = tensor.cast(self.loc, dtype=self.dtype)
             self.scale = tensor.cast(self.scale, dtype=self.dtype)
+        super(Normal, self).__init__(self.loc.shape)

     def sample(self, shape, seed=0):
         """Generate samples of the specified shape.
......
This diff is collapsed.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing

from paddle.distribution import distribution
from paddle.distribution import transform
from paddle.distribution import independent


class TransformedDistribution(distribution.Distribution):
    r"""
    Applies a sequence of Transforms to a base distribution.

    Args:
        base (Distribution): The base distribution.
        transforms (Sequence[Transform]): A sequence of ``Transform`` .

    Examples:

        .. code-block:: python

            import paddle
            from paddle.distribution import transformed_distribution

            d = transformed_distribution.TransformedDistribution(
                paddle.distribution.Normal(0., 1.),
                [paddle.distribution.AffineTransform(paddle.to_tensor(1.), paddle.to_tensor(2.))]
            )

            print(d.sample([10]))
            # Tensor(shape=[10], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [-0.10697651,  3.33609009, -0.86234951,  5.07457638,  0.75925219,
            #         -4.17087793,  2.22579336, -0.93845034,  0.66054249,  1.50957513])
            print(d.log_prob(paddle.to_tensor(0.5)))
            # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [-1.64333570])
    """

    def __init__(self, base, transforms):
        if not isinstance(base, distribution.Distribution):
            raise TypeError(
                f"Expected type of 'base' is Distribution, but got {type(base)}."
            )
        if not isinstance(transforms, typing.Sequence):
            raise TypeError(
                f"Expected type of 'transforms' is Sequence[Transform] or Chain, but got {type(transforms)}."
            )
        if not all(isinstance(t, transform.Transform) for t in transforms):
            raise TypeError("All element of transforms must be Transform type.")

        chain = transform.ChainTransform(transforms)
        base_shape = base.batch_shape + base.event_shape
        if len(base_shape) < chain._domain.event_rank:
            raise ValueError(
                f"'base' needs to have shape with size at least {chain._domain.event_rank}, but got {len(base_shape)}."
            )
        if chain._domain.event_rank > len(base.event_shape):
            # Promote enough batch dimensions of the base distribution to
            # event dimensions to match the domain event rank of the chain.
            base = independent.Independent(
                base, chain._domain.event_rank - len(base.event_shape))
        self._base = base
        self._transforms = transforms

        transformed_shape = chain.forward_shape(base.batch_shape +
                                                base.event_shape)
        transformed_event_rank = chain._codomain.event_rank + \
            max(len(base.event_shape)-chain._domain.event_rank, 0)
        super(TransformedDistribution, self).__init__(
            transformed_shape[:len(transformed_shape) - transformed_event_rank],
            transformed_shape[len(transformed_shape) - transformed_event_rank:])

    def sample(self, shape=()):
        """Sample from ``TransformedDistribution``.

        Args:
            shape (tuple, optional): The sample shape. Defaults to ().

        Returns:
            [Tensor]: The sample result.
        """
        x = self._base.sample(shape)
        for t in self._transforms:
            x = t.forward(x)
        return x

    def log_prob(self, value):
        """The log probability evaluated at value.

        Args:
            value (Tensor): The value to be evaluated.

        Returns:
            Tensor: The log probability.
        """
        log_prob = 0.0
        y = value
        event_rank = len(self.event_shape)
        for t in reversed(self._transforms):
            x = t.inverse(y)
            event_rank += t._domain.event_rank - t._codomain.event_rank
            log_prob = log_prob - \
                _sum_rightmost(t.forward_log_det_jacobian(
                    x), event_rank-t._domain.event_rank)
            y = x
        log_prob += _sum_rightmost(
            self._base.log_prob(y), event_rank - len(self._base.event_shape))
        return log_prob


def _sum_rightmost(value, n):
    return value.sum(list(range(-n, 0))) if n > 0 else value
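log_prob above implements the change-of-variables rule log p_Y(y) = log p_X(f^{-1}(y)) - log|det J_f(f^{-1}(y))|. A standalone SciPy check of that rule for the exponential transform, independent of Paddle:

import numpy as np
import scipy.stats

# y = exp(x) with x ~ Normal(0, 1): log p_Y(y) = log p_X(ln y) - ln y,
# which is exactly the log-density of a standard lognormal distribution.
y = 1.5
manual = scipy.stats.norm.logpdf(np.log(y)) - np.log(y)
print(np.isclose(manual, scipy.stats.lognorm.logpdf(y, s=1.0)))  # True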
@@ -17,18 +17,18 @@ import warnings
 import numpy as np
 from paddle import _C_ops
-
-from ..fluid import core
-from ..fluid.data_feeder import (check_dtype, check_type,
-                                 check_variable_and_dtype, convert_dtype)
-from ..fluid.framework import _non_static_mode
-from ..fluid.layers import (control_flow, elementwise_add, elementwise_div,
-                            elementwise_mul, elementwise_sub, nn, ops, tensor)
-from ..tensor import arange, concat, gather_nd, multinomial
-from .distribution import Distribution
+from paddle.distribution import distribution
+from paddle.fluid import core
+from paddle.fluid.data_feeder import (check_dtype, check_type,
+                                      check_variable_and_dtype, convert_dtype)
+from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
+from paddle.fluid.layers import (control_flow, elementwise_add, elementwise_div,
+                                 elementwise_mul, elementwise_sub, nn, ops,
+                                 tensor)
+from paddle.tensor import arange, concat, gather_nd, multinomial


-class Uniform(Distribution):
+class Uniform(distribution.Distribution):
     r"""Uniform distribution with `low` and `high` parameters.

     Mathematical Details
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle

from paddle.distribution import constraint


class Variable(object):
    """Random variable of probability distribution.

    Args:
        is_discrete (bool): Is the variable discrete or continuous.
        event_rank (int): The rank of event dimensions.
    """

    def __init__(self, is_discrete=False, event_rank=0, constraint=None):
        self._is_discrete = is_discrete
        self._event_rank = event_rank
        self._constraint = constraint

    @property
    def is_discrete(self):
        return self._is_discrete

    @property
    def event_rank(self):
        return self._event_rank

    def constraint(self, value):
        """Check whether the 'value' meets the constraint conditions of this
        random variable."""
        return self._constraint(value)


class Real(Variable):
    def __init__(self, event_rank=0):
        super(Real, self).__init__(False, event_rank, constraint.real)


class Positive(Variable):
    def __init__(self, event_rank=0):
        super(Positive, self).__init__(False, event_rank, constraint.positive)


class Independent(Variable):
    """Reinterprets some of the batch axes of variable as event axes.

    Args:
        base (Variable): Base variable.
        reinterpreted_batch_rank (int): The rightmost batch rank to be
            reinterpreted.
    """

    def __init__(self, base, reinterpreted_batch_rank):
        self._base = base
        self._reinterpreted_batch_rank = reinterpreted_batch_rank
        super(Independent, self).__init__(
            base.is_discrete, base.event_rank + reinterpreted_batch_rank)

    def constraint(self, value):
        ret = self._base.constraint(value)
        if ret.dim() < self._reinterpreted_batch_rank:
            raise ValueError(
                "Input dimensions must be equal or greater than {}".format(
                    self._reinterpreted_batch_rank))
        # Collapse the reinterpreted rightmost axes and require the base
        # constraint to hold on all of them.
        return ret.reshape(ret.shape[:ret.dim() -
                                     self._reinterpreted_batch_rank] +
                           [-1]).all(-1)


class Stack(Variable):
    def __init__(self, vars, axis=0):
        self._vars = vars
        self._axis = axis

    @property
    def is_discrete(self):
        return any(var.is_discrete for var in self._vars)

    @property
    def event_rank(self):
        rank = max(var.event_rank for var in self._vars)
        if self._axis + rank < 0:
            rank += 1
        return rank

    def constraint(self, value):
        if not (-value.dim() <= self._axis < value.dim()):
            raise ValueError(
                f'Input dimensions {value.dim()} should be greater than stack '
                f'constraint axis {self._axis}.')

        return paddle.stack([
            var.constraint(value)
            for var, value in zip(self._vars, paddle.unstack(value, self._axis))
        ], self._axis)


real = Real()
positive = Positive()
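A sketch of how these variable descriptors are meant to be queried, assuming the wrapped constraint returns a boolean tensor as above:

import paddle
from paddle.distribution import variable

# One batch axis of a positive variable reinterpreted as an event axis.
v = variable.Independent(variable.positive, 1)
print(v.event_rank)  # 1
print(v.constraint(paddle.to_tensor([[1., 2.]])))  # holds: every entry >= 0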
@@ -11,11 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import contextlib
-import sys
-
-import numpy as np
-
 import paddle

 DEVICES = [paddle.CPUPlace()]
@@ -34,66 +29,3 @@ RTOL = {
     'complex128': 1e-5
 }
 ATOL = {'float32': 0.0, 'complex64': 0, 'float64': 0.0, 'complex128': 0}
-
-
-def xrand(shape=(10, 10, 10), dtype=DEFAULT_DTYPE, min=1.0, max=10.0):
-    return ((np.random.rand(*shape).astype(dtype)) * (max - min) + min)
-
-
-def place(devices, key='place'):
-    def decorate(cls):
-        module = sys.modules[cls.__module__].__dict__
-        raw_classes = {
-            k: v
-            for k, v in module.items() if k.startswith(cls.__name__)
-        }
-
-        for raw_name, raw_cls in raw_classes.items():
-            for d in devices:
-                test_cls = dict(raw_cls.__dict__)
-                test_cls.update({key: d})
-                new_name = raw_name + '.' + d.__class__.__name__
-                module[new_name] = type(new_name, (raw_cls, ), test_cls)
-            del module[raw_name]
-        return cls
-
-    return decorate
-
-
-def parameterize(fields, values=None):
-    fields = [fields] if isinstance(fields, str) else fields
-    params = [dict(zip(fields, vals)) for vals in values]
-
-    def decorate(cls):
-        test_cls_module = sys.modules[cls.__module__].__dict__
-        for k, v in enumerate(params):
-            test_cls = dict(cls.__dict__)
-            test_cls.update(v)
-            name = cls.__name__ + str(k)
-            name = name + '.' + v.get('suffix') if v.get('suffix') else name
-            test_cls_module[name] = type(name, (cls, ), test_cls)
-
-        for m in list(cls.__dict__):
-            if m.startswith("test"):
-                delattr(cls, m)
-        return cls
-
-    return decorate
-
-
-@contextlib.contextmanager
-def stgraph(func, *args):
-    """static graph exec context"""
-    paddle.enable_static()
-    mp, sp = paddle.static.Program(), paddle.static.Program()
-    with paddle.static.program_guard(mp, sp):
-        input = paddle.static.data('input', x.shape, dtype=x.dtype)
-        output = func(input, n, axes, norm)
-
-        exe = paddle.static.Executor(place)
-        exe.run(sp)
-        [output] = exe.run(mp, feed={'input': x}, fetch_list=[output])
-        yield output
-    paddle.disable_static()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import contextlib
import functools
import inspect
import re
import sys
from unittest import SkipTest

import numpy as np
import paddle

import config

TEST_CASE_NAME = 'suffix'


def xrand(shape=(10, 10, 10), dtype=config.DEFAULT_DTYPE, min=1.0, max=10.0):
    return ((np.random.rand(*shape).astype(dtype)) * (max - min) + min)


def place(devices, key='place'):
    def decorate(cls):
        module = sys.modules[cls.__module__].__dict__
        raw_classes = {
            k: v
            for k, v in module.items() if k.startswith(cls.__name__)
        }

        for raw_name, raw_cls in raw_classes.items():
            for d in devices:
                test_cls = dict(raw_cls.__dict__)
                test_cls.update({key: d})
                new_name = raw_name + '.' + d.__class__.__name__
                module[new_name] = type(new_name, (raw_cls, ), test_cls)
            del module[raw_name]
        return cls

    return decorate


def parameterize_cls(fields, values=None):
    fields = [fields] if isinstance(fields, str) else fields
    params = [dict(zip(fields, vals)) for vals in values]

    def decorate(cls):
        test_cls_module = sys.modules[cls.__module__].__dict__
        for k, v in enumerate(params):
            test_cls = dict(cls.__dict__)
            test_cls.update(v)
            name = cls.__name__ + str(k)
            name = name + '.' + v.get('suffix') if v.get('suffix') else name
            test_cls_module[name] = type(name, (cls, ), test_cls)

        for m in list(cls.__dict__):
            if m.startswith("test"):
                delattr(cls, m)
        return cls

    return decorate


def parameterize_func(input, name_func=None, doc_func=None,
                      skip_on_empty=False):
    doc_func = doc_func or default_doc_func
    name_func = name_func or default_name_func

    def wrapper(f, instance=None):
        frame_locals = inspect.currentframe().f_back.f_locals

        parameters = input_as_callable(input)()

        if not parameters:
            if not skip_on_empty:
                raise ValueError(
                    "Parameters iterable is empty (hint: use "
                    "`parameterized.expand([], skip_on_empty=True)` to skip "
                    "this test when the input is empty)")
            return functools.wraps(f)(skip_on_empty_helper)

        digits = len(str(len(parameters) - 1))
        for num, p in enumerate(parameters):
            name = name_func(
                f, "{num:0>{digits}}".format(
                    digits=digits, num=num), p)
            # If the original function has patches applied by 'mock.patch',
            # re-construct all patches on the just former decoration layer
            # of param_as_standalone_func so as not to share
            # patch objects between new functions
            nf = reapply_patches_if_need(f)
            frame_locals[name] = param_as_standalone_func(p, nf, name)
            frame_locals[name].__doc__ = doc_func(f, num, p)

        # Delete original patches to prevent new function from evaluating
        # original patching object as well as re-constructed patches.
        delete_patches_if_need(f)

        f.__test__ = False

    return wrapper


def reapply_patches_if_need(func):
    def dummy_wrapper(orgfunc):
        @functools.wraps(orgfunc)
        def dummy_func(*args, **kwargs):
            return orgfunc(*args, **kwargs)

        return dummy_func

    if hasattr(func, 'patchings'):
        func = dummy_wrapper(func)
        tmp_patchings = func.patchings
        delattr(func, 'patchings')
        for patch_obj in tmp_patchings:
            func = patch_obj.decorate_callable(func)
    return func


def delete_patches_if_need(func):
    if hasattr(func, 'patchings'):
        func.patchings[:] = []


def default_name_func(func, num, p):
    base_name = func.__name__
    name_suffix = "_%s" % (num, )

    if len(p.args) > 0 and isinstance(p.args[0], str):
        name_suffix += "_" + to_safe_name(p.args[0])
    return base_name + name_suffix


def default_doc_func(func, num, p):
    if func.__doc__ is None:
        return None

    # NOTE: ``parameterized_argument_value_pairs``, ``short_repr`` and
    # ``to_text`` are helpers from the upstream ``parameterized`` package this
    # code is adapted from; this branch is only reached when the decorated
    # test function carries a docstring.
    all_args_with_values = parameterized_argument_value_pairs(func, p)

    # Assumes that the function passed is a bound method.
    descs = ["%s=%s" % (n, short_repr(v)) for n, v in all_args_with_values]

    # The documentation might be a multiline string, so split it
    # and just work with the first string, ignoring the period
    # at the end if there is one.
    first, nl, rest = func.__doc__.lstrip().partition("\n")
    suffix = ""
    if first.endswith("."):
        suffix = "."
        first = first[:-1]
    args = "%s[with %s]" % (len(first) and " " or "", ", ".join(descs))
    return "".join(to_text(x) for x in [first.rstrip(), args, suffix, nl, rest])


def param_as_standalone_func(p, func, name):
    @functools.wraps(func)
    def standalone_func(*a):
        return func(*(a + p.args), **p.kwargs)

    standalone_func.__name__ = name

    # place_as is used by py.test to determine what source file should be
    # used for this test.
    standalone_func.place_as = func

    # Remove __wrapped__ because py.test will try to look at __wrapped__
    # to determine which parameters should be used with this test case,
    # and obviously we don't need it to do any parameterization.
    try:
        del standalone_func.__wrapped__
    except AttributeError:
        pass
    return standalone_func


def input_as_callable(input):
    if callable(input):
        return lambda: check_input_values(input())
    input_values = check_input_values(input)
    return lambda: input_values


def check_input_values(input_values):
    if not isinstance(input_values, list):
        input_values = list(input_values)
    return [param.from_decorator(p) for p in input_values]


def skip_on_empty_helper(*a, **kw):
    raise SkipTest("parameterized input is empty")


_param = collections.namedtuple("param", "args kwargs")


class param(_param):
    def __new__(cls, *args, **kwargs):
        return _param.__new__(cls, args, kwargs)

    @classmethod
    def explicit(cls, args=None, kwargs=None):
        """ Creates a ``param`` by explicitly specifying ``args`` and
        ``kwargs``::

            >>> param.explicit([1,2,3])
            param(*(1, 2, 3))
            >>> param.explicit(kwargs={"foo": 42})
            param(*(), **{"foo": "42"})
        """
        args = args or ()
        kwargs = kwargs or {}
        return cls(*args, **kwargs)

    @classmethod
    def from_decorator(cls, args):
        """ Returns an instance of ``param()`` for ``@parameterized`` argument
        ``args``::

            >>> param.from_decorator((42, ))
            param(args=(42, ), kwargs={})
            >>> param.from_decorator("foo")
            param(args=("foo", ), kwargs={})
        """
        if isinstance(args, param):
            return args
        elif isinstance(args, str):
            args = (args, )
        try:
            return cls(*args)
        except TypeError as e:
            if "after * must be" not in str(e):
                raise
            raise TypeError(
                "Parameters must be tuples, but %r is not (hint: use '(%r, )')"
                % (args, args), )

    def __repr__(self):
        return "param(*%r, **%r)" % self


def to_safe_name(s):
    return str(re.sub("[^a-zA-Z0-9_]+", "_", s))


@contextlib.contextmanager
def stgraph(func, place, x, *args):
    """Static graph exec context: build ``func(input, *args)`` in a static
    program, feed ``x`` on ``place`` and yield the fetched output."""
    paddle.enable_static()
    mp, sp = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(mp, sp):
        input = paddle.static.data('input', x.shape, dtype=x.dtype)
        output = func(input, *args)

        exe = paddle.static.Executor(place)
        exe.run(sp)
        [output] = exe.run(mp, feed={'input': x}, fetch_list=[output])
        yield output
    paddle.disable_static()


# alias
parameterize = parameterize_func
param_cls = parameterize_cls
param_func = parameterize_func
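The intended usage pattern, mirroring the test files below (the class and parameter names here are hypothetical):

import unittest

import config
import parameterize as param


@param.place(config.DEVICES)
@param.param_cls((param.TEST_CASE_NAME, 'value'),
                 [('positive-int', 1), ('positive-float', 0.5)])
class TestExample(unittest.TestCase):
    # param_cls generates one subclass per (suffix, value) tuple with
    # `self.value` bound; place then duplicates each over config.DEVICES.
    def test_value_positive(self):
        self.assertGreater(self.value, 0)


if __name__ == '__main__':
    unittest.main()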
@@ -22,6 +22,7 @@ from paddle.distribution import *
 from paddle.fluid import layers

 import config
+import parameterize

 paddle.enable_static()
@@ -132,11 +133,12 @@ class DistributionTestName(unittest.TestCase):
         self.assertEqual(self.get_prefix(lp.name), name + '_log_prob')


-@config.place(config.DEVICES)
-@config.parameterize((config.TEST_CASE_NAME, 'batch_shape', 'event_shape'),
-                     [('test-tuple', (10, 20),
-                       (10, 20)), ('test-list', [100, 100], [100, 200, 300]),
-                      ('test-null-eventshape', (100, 100), ())])
+@parameterize.place(config.DEVICES)
+@parameterize.parameterize_cls(
+    (parameterize.TEST_CASE_NAME, 'batch_shape', 'event_shape'),
+    [('test-tuple', (10, 20), (10, 20)),
+     ('test-list', [100, 100], [100, 200, 300]), ('test-null-eventshape',
+                                                  (100, 100), ())])
 class TestDistributionShape(unittest.TestCase):
     def setUp(self):
         paddle.disable_static()
@@ -156,7 +158,7 @@ class TestDistributionShape(unittest.TestCase):

     def test_prob(self):
         with self.assertRaises(NotImplementedError):
-            self.dist.prob(paddle.to_tensor(config.xrand()))
+            self.dist.prob(paddle.to_tensor(parameterize.xrand()))

     def test_extend_shape(self):
         shapes = [(34, 20), (56, ), ()]
@@ -164,3 +166,24 @@
         self.assertTrue(
             self.dist._extend_shape(shape),
             shape + self.dist.batch_shape + self.dist.event_shape)
+
+
+class TestDistributionException(unittest.TestCase):
+    def setUp(self):
+        self._d = paddle.distribution.Distribution()
+
+    def test_mean(self):
+        with self.assertRaises(NotImplementedError):
+            self._d.mean
+
+    def test_variance(self):
+        with self.assertRaises(NotImplementedError):
+            self._d.variance
+
+    def test_rsample(self):
+        with self.assertRaises(NotImplementedError):
+            self._d.rsample(())
+
+
+if __name__ == '__main__':
+    unittest.main()
@@ -18,14 +18,15 @@ import numpy as np
 import paddle
 import scipy.stats

-from config import (ATOL, DEVICES, RTOL, TEST_CASE_NAME, parameterize, place,
-                    xrand)
+import config
+from config import ATOL, DEVICES, RTOL
+from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand


 @place(DEVICES)
-@parameterize((TEST_CASE_NAME, 'alpha', 'beta'),
-              [('test-scale', 1.0, 2.0), ('test-tensor', xrand(), xrand()),
-               ('test-broadcast', xrand((2, 1)), xrand((2, 5)))])
+@parameterize_cls((TEST_CASE_NAME, 'alpha', 'beta'),
+                  [('test-scale', 1.0, 2.0), ('test-tensor', xrand(), xrand()),
+                   ('test-broadcast', xrand((2, 1)), xrand((2, 5)))])
 class TestBeta(unittest.TestCase):
     def setUp(self):
         # scale no need convert to tensor for scale input unittest
@@ -98,3 +99,7 @@ class TestBeta(unittest.TestCase):
         self.assertTrue(
             self._paddle_beta.sample(case.get('input')).shape ==
             case.get('expect'))
+
+
+if __name__ == '__main__':
+    unittest.main()
@@ -18,16 +18,19 @@ import numpy as np
 import paddle
 import scipy.stats

-from config import (ATOL, DEVICES, RTOL, TEST_CASE_NAME, parameterize, place,
-                    xrand)
+import config
+import parameterize as param
+from config import ATOL, RTOL
+from parameterize import xrand

 paddle.enable_static()


-@place(DEVICES)
-@parameterize((TEST_CASE_NAME, 'alpha', 'beta'), [('test-tensor', xrand(
-    (10, 10)), xrand((10, 10))), ('test-broadcast', xrand((2, 1)), xrand(
-        (2, 5))), ('test-larger-data', xrand((10, 20)), xrand((10, 20)))])
+@param.place(config.DEVICES)
+@param.parameterize_cls(
+    (param.TEST_CASE_NAME, 'alpha', 'beta'), [('test-tensor', xrand(
+        (10, 10)), xrand((10, 10))), ('test-broadcast', xrand((2, 1)), xrand(
+            (2, 5))), ('test-larger-data', xrand((10, 20)), xrand((10, 20)))])
 class TestBeta(unittest.TestCase):
     def setUp(self):
         self.program = paddle.static.Program()
......
@@ -439,3 +439,7 @@ class DistributionTestError(unittest.TestCase):
             cat.log_prob(value)

         self.assertRaises(ValueError, test_shape_not_match_error)
+
+
+if __name__ == '__main__':
+    unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np
import paddle

from paddle.distribution import constraint

import config
import parameterize as param


@param.param_cls((param.TEST_CASE_NAME, 'value'),
                 [('NotImplement', np.random.rand(2, 3))])
class TestConstraint(unittest.TestCase):
    def setUp(self):
        self._constraint = constraint.Constraint()

    def test_constraint(self):
        with self.assertRaises(NotImplementedError):
            self._constraint(self.value)


@param.param_cls((param.TEST_CASE_NAME, 'value', 'expect'),
                 [('real', 1., True)])
class TestReal(unittest.TestCase):
    def setUp(self):
        self._constraint = constraint.Real()

    def test_constraint(self):
        self.assertEqual(self._constraint(self.value), self.expect)


@param.param_cls((param.TEST_CASE_NAME, 'lower', 'upper', 'value', 'expect'),
                 [('in_range', 0, 1, 0.5, True), ('out_range', 0, 1, 2, False)])
class TestRange(unittest.TestCase):
    def setUp(self):
        self._constraint = constraint.Range(self.lower, self.upper)

    def test_constraint(self):
        self.assertEqual(self._constraint(self.value), self.expect)


@param.param_cls((param.TEST_CASE_NAME, 'value', 'expect'),
                 [('positive', 1, True), ('negative', -1, False)])
class TestPositive(unittest.TestCase):
    def setUp(self):
        self._constraint = constraint.Positive()

    def test_constraint(self):
        self.assertEqual(self._constraint(self.value), self.expect)


@param.param_cls((param.TEST_CASE_NAME, 'value', 'expect'),
                 [('simplex', paddle.to_tensor([0.5, 0.5]), True),
                  ('non_simplex', paddle.to_tensor([-0.5, 0.5]), False)])
class TestSimplex(unittest.TestCase):
    def setUp(self):
        self._constraint = constraint.Simplex()

    def test_constraint(self):
        self.assertEqual(self._constraint(self.value), self.expect)


if __name__ == '__main__':
    unittest.main()
@@ -19,15 +19,15 @@ import paddle
 import scipy.stats

 import config
-from config import (ATOL, DEVICES, RTOL, TEST_CASE_NAME, parameterize, place,
-                    xrand)
+from config import ATOL, DEVICES, RTOL
+import parameterize as param


-@place(DEVICES)
-@parameterize(
-    (TEST_CASE_NAME, 'concentration'),
+@param.place(DEVICES)
+@param.param_cls(
+    (param.TEST_CASE_NAME, 'concentration'),
     [
-        ('test-one-dim', config.xrand((89, ))),
+        ('test-one-dim', param.xrand((89, ))),
         # ('test-multi-dim', config.xrand((10, 20, 30)))
     ])
 class TestDirichlet(unittest.TestCase):
@@ -91,14 +91,18 @@ class TestDirichlet(unittest.TestCase):
         self.assertTrue(
             np.all(
                 self._paddle_diric._log_normalizer(
-                    paddle.to_tensor(config.xrand((100, 100, 100)))).numpy() <
+                    paddle.to_tensor(param.xrand((100, 100, 100)))).numpy() <
                 0.0))


-@place(DEVICES)
-@parameterize((TEST_CASE_NAME, 'concentration'),
-              [('test-zero-dim', np.array(1.0))])
+@param.place(DEVICES)
+@param.param_cls((param.TEST_CASE_NAME, 'concentration'),
+                 [('test-zero-dim', np.array(1.0))])
 class TestDirichletException(unittest.TestCase):
     def TestInit(self):
         with self.assertRaises(ValueError):
             paddle.distribution.Dirichlet(
                 paddle.squeeze(self.concentration))
+
+
+if __name__ == '__main__':
+    unittest.main()
@@ -18,15 +18,15 @@ import numpy as np
 import paddle
 import scipy.stats

-from config import (ATOL, DEVICES, RTOL, TEST_CASE_NAME, parameterize, place,
-                    xrand)
+from config import ATOL, DEVICES, RTOL
+from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand

 paddle.enable_static()


 @place(DEVICES)
-@parameterize((TEST_CASE_NAME, 'concentration'),
-              [('test-one-dim', np.random.rand(89) + 5.0)])
+@parameterize_cls((TEST_CASE_NAME, 'concentration'),
+                  [('test-one-dim', np.random.rand(89) + 5.0)])
 class TestDirichlet(unittest.TestCase):
     def setUp(self):
         self.program = paddle.static.Program()
......
@@ -20,14 +20,15 @@ import scipy.stats

 import config
 import mock_data as mock
+import parameterize


-@config.place(config.DEVICES)
-@config.parameterize(
-    (config.TEST_CASE_NAME, 'dist'), [('test-mock-exp',
-                                       mock.Exponential(rate=paddle.rand(
-                                           [100, 200, 99],
-                                           dtype=config.DEFAULT_DTYPE)))])
+@parameterize.place(config.DEVICES)
+@parameterize.parameterize_cls(
+    (parameterize.TEST_CASE_NAME, 'dist'), [('test-mock-exp',
+                                             mock.Exponential(rate=paddle.rand(
+                                                 [100, 200, 99],
+                                                 dtype=config.DEFAULT_DTYPE)))])
 class TestExponentialFamily(unittest.TestCase):
     def test_entropy(self):
         np.testing.assert_allclose(
@@ -37,15 +38,15 @@ class TestExponentialFamily(unittest.TestCase):
             atol=config.ATOL.get(config.DEFAULT_DTYPE))


-@config.place(config.DEVICES)
-@config.parameterize(
-    (config.TEST_CASE_NAME, 'dist'),
-    [('test-dummy', mock.DummyExpFamily(0.5, 0.5)),
-     ('test-dirichlet',
-      paddle.distribution.Dirichlet(paddle.to_tensor(config.xrand()))), (
-          'test-beta', paddle.distribution.Beta(
-              paddle.to_tensor(config.xrand()),
-              paddle.to_tensor(config.xrand())))])
+@parameterize.place(config.DEVICES)
+@parameterize.parameterize_cls(
+    (config.TEST_CASE_NAME, 'dist'),
+    [('test-dummy', mock.DummyExpFamily(0.5, 0.5)),
+     ('test-dirichlet',
+      paddle.distribution.Dirichlet(paddle.to_tensor(parameterize.xrand()))), (
+          'test-beta', paddle.distribution.Beta(
+              paddle.to_tensor(parameterize.xrand()),
+              paddle.to_tensor(parameterize.xrand())))])
 class TestExponentialFamilyException(unittest.TestCase):
     def test_entropy_exception(self):
         with self.assertRaises(NotImplementedError):
......
@@ -20,17 +20,18 @@ import scipy.stats

 import config
 import mock_data as mock
+import parameterize

 paddle.enable_static()


-@config.place(config.DEVICES)
+@parameterize.place(config.DEVICES)
 class TestExponentialFamily(unittest.TestCase):
     def setUp(self):
         self.program = paddle.static.Program()
         self.executor = paddle.static.Executor()
         with paddle.static.program_guard(self.program):
-            rate_np = config.xrand((100, 200, 99))
+            rate_np = parameterize.xrand((100, 200, 99))
             rate = paddle.static.data('rate', rate_np.shape, rate_np.dtype)
             self.mock_dist = mock.Exponential(rate)
             self.feeds = {'rate': rate_np}
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
import unittest

import numpy as np
import paddle
import scipy.stats

import config
import parameterize as param


@param.place(config.DEVICES)
@param.param_cls((param.TEST_CASE_NAME, 'base', 'reinterpreted_batch_rank'),
                 [('base_beta', paddle.distribution.Beta(
                     paddle.rand([1, 2]), paddle.rand([1, 2])), 1)])
class TestIndependent(unittest.TestCase):
    def setUp(self):
        self._t = paddle.distribution.Independent(self.base,
                                                  self.reinterpreted_batch_rank)

    def test_mean(self):
        np.testing.assert_allclose(
            self.base.mean,
            self._t.mean,
            rtol=config.RTOL.get(str(self.base.alpha.numpy().dtype)),
            atol=config.ATOL.get(str(self.base.alpha.numpy().dtype)))

    def test_variance(self):
        np.testing.assert_allclose(
            self.base.variance,
            self._t.variance,
            rtol=config.RTOL.get(str(self.base.alpha.numpy().dtype)),
            atol=config.ATOL.get(str(self.base.alpha.numpy().dtype)))

    def test_entropy(self):
        np.testing.assert_allclose(
            self._np_sum_rightmost(self.base.entropy().numpy(),
                                   self.reinterpreted_batch_rank),
            self._t.entropy(),
            rtol=config.RTOL.get(str(self.base.alpha.numpy().dtype)),
            atol=config.ATOL.get(str(self.base.alpha.numpy().dtype)))

    def _np_sum_rightmost(self, value, n):
        return np.sum(value, tuple(range(-n, 0))) if n > 0 else value

    def test_log_prob(self):
        value = np.random.rand(1)
        np.testing.assert_allclose(
            self._np_sum_rightmost(
                self.base.log_prob(paddle.to_tensor(value)).numpy(),
                self.reinterpreted_batch_rank),
            self._t.log_prob(paddle.to_tensor(value)).numpy(),
            rtol=config.RTOL.get(str(self.base.alpha.numpy().dtype)),
            atol=config.ATOL.get(str(self.base.alpha.numpy().dtype)))

    # TODO(cxxly): Add Kolmogorov-Smirnov test for sample result.
    def test_sample(self):
        shape = (5, 10, 8)
        expected_shape = (5, 10, 8, 1, 2)
        data = self._t.sample(shape)
        self.assertEqual(tuple(data.shape), expected_shape)
        self.assertEqual(data.dtype, self.base.alpha.dtype)


@param.place(config.DEVICES)
@param.param_cls(
    (param.TEST_CASE_NAME, 'base', 'reinterpreted_batch_rank',
     'expected_exception'),
    [('base_not_transform', '', 1, TypeError),
     ('rank_less_than_zero', paddle.distribution.Transform(), -1, ValueError)])
class TestIndependentException(unittest.TestCase):
    def test_init(self):
        with self.assertRaises(self.expected_exception):
            paddle.distribution.IndependentTransform(
                self.base, self.reinterpreted_batch_rank)


if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
import unittest

import numpy as np
import paddle
import scipy.stats

import config
import parameterize as param

paddle.enable_static()


@param.place(config.DEVICES)
@param.param_cls(
    (param.TEST_CASE_NAME, 'base', 'reinterpreted_batch_rank', 'alpha', 'beta'),
    [('base_beta', paddle.distribution.Beta, 1, np.random.rand(1, 2),
      np.random.rand(1, 2))])
class TestIndependent(unittest.TestCase):
    def setUp(self):
        value = np.random.rand(1)
        self.dtype = value.dtype
        exe = paddle.static.Executor()
        sp = paddle.static.Program()
        mp = paddle.static.Program()
        with paddle.static.program_guard(mp, sp):
            alpha = paddle.static.data('alpha', self.alpha.shape,
                                       self.alpha.dtype)
            beta = paddle.static.data('beta', self.beta.shape, self.beta.dtype)
            self.base = self.base(alpha, beta)
            t = paddle.distribution.Independent(self.base,
                                                self.reinterpreted_batch_rank)
            mean = t.mean
            variance = t.variance
            entropy = t.entropy()
            static_value = paddle.static.data('value', value.shape, value.dtype)
            log_prob = t.log_prob(static_value)

            base_mean = self.base.mean
            base_variance = self.base.variance
            base_entropy = self.base.entropy()
            base_log_prob = self.base.log_prob(static_value)
        fetch_list = [
            mean, variance, entropy, log_prob, base_mean, base_variance,
            base_entropy, base_log_prob
        ]
        exe.run(sp)
        [
            self.mean, self.variance, self.entropy, self.log_prob,
            self.base_mean, self.base_variance, self.base_entropy,
            self.base_log_prob
        ] = exe.run(
            mp,
            feed={'value': value,
                  'alpha': self.alpha,
                  'beta': self.beta},
            fetch_list=fetch_list)

    def test_mean(self):
        np.testing.assert_allclose(
            self.mean,
            self.base_mean,
            rtol=config.RTOL.get(str(self.dtype)),
            atol=config.ATOL.get(str(self.dtype)))

    def test_variance(self):
        np.testing.assert_allclose(
            self.variance,
            self.base_variance,
            rtol=config.RTOL.get(str(self.dtype)),
            atol=config.ATOL.get(str(self.dtype)))

    def test_entropy(self):
        np.testing.assert_allclose(
            self._np_sum_rightmost(self.base_entropy,
                                   self.reinterpreted_batch_rank),
            self.entropy,
            rtol=config.RTOL.get(str(self.dtype)),
            atol=config.ATOL.get(str(self.dtype)))

    def _np_sum_rightmost(self, value, n):
        return np.sum(value, tuple(range(-n, 0))) if n > 0 else value

    def test_log_prob(self):
        np.testing.assert_allclose(
            self._np_sum_rightmost(self.base_log_prob,
                                   self.reinterpreted_batch_rank),
            self.log_prob,
            rtol=config.RTOL.get(str(self.dtype)),
            atol=config.ATOL.get(str(self.dtype)))


if __name__ == '__main__':
    unittest.main()
@@ -19,15 +19,17 @@ import paddle
 import scipy.stats

 import config
+import parameterize


-@config.place(config.DEVICES)
-@config.parameterize((config.TEST_CASE_NAME, 'total_count', 'probs'), [
-    ('one-dim', 10, config.xrand((3, ))),
-    ('multi-dim', 9, config.xrand((10, 20))),
-    ('prob-sum-one', 10, np.array([0.5, 0.2, 0.3])),
-    ('prob-sum-non-one', 10, np.array([2., 3., 5.])),
-])
+@parameterize.place(config.DEVICES)
+@parameterize.parameterize_cls(
+    (parameterize.TEST_CASE_NAME, 'total_count', 'probs'), [
+        ('one-dim', 10, parameterize.xrand((3, ))),
+        ('multi-dim', 9, parameterize.xrand((10, 20))),
+        ('prob-sum-one', 10, np.array([0.5, 0.2, 0.3])),
+        ('prob-sum-non-one', 10, np.array([2., 3., 5.])),
+    ])
 class TestMultinomial(unittest.TestCase):
     def setUp(self):
         self._dist = paddle.distribution.Multinomial(
@@ -98,9 +100,9 @@ class TestMultinomial(unittest.TestCase):
         return scipy.stats.multinomial.entropy(self.total_count, probs)


-@config.place(config.DEVICES)
-@config.parameterize(
-    (config.TEST_CASE_NAME, 'total_count', 'probs', 'value'),
+@parameterize.place(config.DEVICES)
+@parameterize.parameterize_cls(
+    (parameterize.TEST_CASE_NAME, 'total_count', 'probs', 'value'),
     [
         ('value-float', 10, np.array([0.2, 0.3, 0.5]), np.array([2., 3., 5.])),
         ('value-int', 10, np.array([0.2, 0.3, 0.5]), np.array([2, 3, 5])),
@@ -122,12 +124,13 @@ class TestMultinomialPmf(unittest.TestCase):
             atol=config.ATOL.get(str(self.probs.dtype)))


-@config.place(config.DEVICES)
-@config.parameterize((config.TEST_CASE_NAME, 'total_count', 'probs'), [
-    ('total_count_le_one', 0, np.array([0.3, 0.7])),
-    ('total_count_float', np.array([0.3, 0.7])),
-    ('probs_zero_dim', np.array(0)),
-])
+@parameterize.place(config.DEVICES)
+@parameterize.parameterize_cls(
+    (config.TEST_CASE_NAME, 'total_count', 'probs'), [
+        ('total_count_le_one', 0, np.array([0.3, 0.7])),
+        ('total_count_float', np.array([0.3, 0.7])),
+        ('probs_zero_dim', np.array(0)),
+    ])
 class TestMultinomialException(unittest.TestCase):
     def TestInit(self):
         with self.assertRaises(ValueError):
......
@@ -19,17 +19,19 @@ import paddle
 import scipy.stats
 
 import config
+import parameterize
 
 paddle.enable_static()
 
 
-@config.place(config.DEVICES)
-@config.parameterize((config.TEST_CASE_NAME, 'total_count', 'probs'), [
-    ('one-dim', 5, config.xrand((3, ))),
-    ('multi-dim', 9, config.xrand((2, 3))),
-    ('prob-sum-one', 5, np.array([0.5, 0.2, 0.3])),
-    ('prob-sum-non-one', 5, np.array([2., 3., 5.])),
-])
+@parameterize.place(config.DEVICES)
+@parameterize.parameterize_cls(
+    (parameterize.TEST_CASE_NAME, 'total_count', 'probs'), [
+        ('one-dim', 5, parameterize.xrand((3, ))),
+        ('multi-dim', 9, parameterize.xrand((2, 3))),
+        ('prob-sum-one', 5, np.array([0.5, 0.2, 0.3])),
+        ('prob-sum-non-one', 5, np.array([2., 3., 5.])),
+    ])
 class TestMultinomial(unittest.TestCase):
     def setUp(self):
         startup_program = paddle.static.Program()
@@ -99,9 +101,9 @@ class TestMultinomial(unittest.TestCase):
         return scipy.stats.multinomial.entropy(self.total_count, probs)
 
 
-@config.place(config.DEVICES)
-@config.parameterize(
-    (config.TEST_CASE_NAME, 'total_count', 'probs', 'value'),
+@parameterize.place(config.DEVICES)
+@parameterize.parameterize_cls(
+    (parameterize.TEST_CASE_NAME, 'total_count', 'probs', 'value'),
     [
         ('value-float', 5, np.array([0.2, 0.3, 0.5]), np.array([1., 1., 3.])),
         ('value-int', 5, np.array([0.2, 0.3, 0.5]), np.array([2, 2, 1])),
@@ -139,12 +141,13 @@ class TestMultinomialPmf(unittest.TestCase):
             atol=config.ATOL.get(str(self.probs.dtype)))
 
 
-@config.place(config.DEVICES)
-@config.parameterize((config.TEST_CASE_NAME, 'total_count', 'probs'), [
-    ('total_count_le_one', 0, np.array([0.3, 0.7])),
-    ('total_count_float', np.array([0.3, 0.7])),
-    ('probs_zero_dim', np.array(0)),
-])
+@parameterize.place(config.DEVICES)
+@parameterize.parameterize_cls(
+    (config.TEST_CASE_NAME, 'total_count', 'probs'), [
+        ('total_count_le_one', 0, np.array([0.3, 0.7])),
+        ('total_count_float', np.array([0.3, 0.7])),
+        ('probs_zero_dim', np.array(0)),
+    ])
 class TestMultinomialException(unittest.TestCase):
     def setUp(self):
         startup_program = paddle.static.Program()
......
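The renames above swap the ad-hoc helpers in `config` for a dedicated `parameterize` module. For readers unfamiliar with the pattern, the sketch below is a minimal, hypothetical reimplementation of what a class-level parameterization decorator like `parameterize_cls` typically does; Paddle's actual helper may differ in details. Each named tuple of values becomes a generated subclass with the values bound as class attributes, which is why the test bodies can read `self.total_count` and `self.probs` directly.

import sys
import unittest

TEST_CASE_NAME = 'suffix'  # name of the field that labels each case


def parameterize_cls(fields, params):
    """Hypothetical sketch: generate one subclass per parameter tuple,
    binding the named fields as class attributes."""
    def decorator(cls):
        namespace = sys.modules[cls.__module__].__dict__
        for case in params:
            attrs = dict(zip(fields, case))
            suffix = attrs.pop(TEST_CASE_NAME)
            name = '%s_%s' % (cls.__name__, suffix)
            namespace[name] = type(name, (cls,), attrs)
        # Drop the undecorated template so unittest does not collect it.
        return None
    return decorator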
@@ -454,3 +454,7 @@ class NormalTest10(NormalTest):
         with fluid.program_guard(self.test_program):
             self.static_values = layers.data(
                 name='values', shape=[dims], dtype='float32')
+
+
+if __name__ == '__main__':
+    unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numbers
import unittest

import numpy as np
import paddle
import scipy.stats

import config
import parameterize as param


@param.place(config.DEVICES)
@param.param_cls((param.TEST_CASE_NAME, 'base', 'transforms'),
                 [('base_normal', paddle.distribution.Normal(0., 1.),
                   [paddle.distribution.ExpTransform()])])
class TestIndependent(unittest.TestCase):
    def setUp(self):
        self._t = paddle.distribution.TransformedDistribution(self.base,
                                                              self.transforms)

    def _np_sum_rightmost(self, value, n):
        return np.sum(value, tuple(range(-n, 0))) if n > 0 else value

    def test_log_prob(self):
        value = paddle.to_tensor(0.5)
        np.testing.assert_allclose(
            self.simple_log_prob(value, self.base, self.transforms),
            self._t.log_prob(value),
            rtol=config.RTOL.get(str(value.numpy().dtype)),
            atol=config.ATOL.get(str(value.numpy().dtype)))

    def simple_log_prob(self, value, base, transforms):
        log_prob = 0.0
        y = value
        for t in reversed(transforms):
            x = t.inverse(y)
            log_prob = log_prob - t.forward_log_det_jacobian(x)
            y = x
        log_prob += base.log_prob(y)
        return log_prob

    # TODO(cxxly): Add Kolmogorov-Smirnov test for sample result.
    def test_sample(self):
        shape = [5, 10, 8]
        expected_shape = (5, 10, 8)
        data = self._t.sample(shape)
        self.assertEqual(tuple(data.shape), expected_shape)
        self.assertEqual(data.dtype, self.base.loc.dtype)


if __name__ == '__main__':
    unittest.main()
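The class above (named `TestIndependent`, though it actually exercises `TransformedDistribution`) checks `log_prob` against a manual change-of-variables computation: log p_Y(y) = log p_X(T^-1(y)) - log|det J_T(T^-1(y))|, accumulated in reverse over the transform list. For the parameterization used here, Normal(0, 1) pushed through ExpTransform, the transformed variable is log-normal, so the same number can be cross-checked independently of Paddle. A small sketch, assuming only numpy and scipy:

import numpy as np
import scipy.stats

# y = exp(x) with x ~ Normal(0, 1): the change-of-variables formula gives
# log p_Y(y) = log p_X(log y) - log|dy/dx| evaluated at x = log y,
# and ExpTransform's forward log-det-Jacobian at x is simply x.
y = 0.5
x = np.log(y)
manual = scipy.stats.norm(0., 1.).logpdf(x) - x
reference = scipy.stats.lognorm(s=1.).logpdf(y)  # log-normal density
np.testing.assert_allclose(manual, reference, rtol=1e-10)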
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numbers
import unittest

import numpy as np
import paddle
import scipy.stats

import config
import parameterize as param

paddle.enable_static()


@param.place(config.DEVICES)
@param.param_cls((param.TEST_CASE_NAME, 'base', 'transforms'),
                 [('base_normal', paddle.distribution.Normal,
                   [paddle.distribution.ExpTransform()])])
class TestIndependent(unittest.TestCase):
    def setUp(self):
        value = np.array([0.5])
        loc = np.array([0.])
        scale = np.array([1.])
        shape = [5, 10, 8]
        self.dtype = value.dtype

        exe = paddle.static.Executor()
        sp = paddle.static.Program()
        mp = paddle.static.Program()
        with paddle.static.program_guard(mp, sp):
            static_value = paddle.static.data('value', value.shape, value.dtype)
            static_loc = paddle.static.data('loc', loc.shape, loc.dtype)
            static_scale = paddle.static.data('scale', scale.shape, scale.dtype)
            self.base = self.base(static_loc, static_scale)
            self._t = paddle.distribution.TransformedDistribution(
                self.base, self.transforms)
            actual_log_prob = self._t.log_prob(static_value)
            expected_log_prob = self.transformed_log_prob(
                static_value, self.base, self.transforms)
            sample_data = self._t.sample(shape)
        exe.run(sp)
        [self.actual_log_prob, self.expected_log_prob,
         self.sample_data] = exe.run(
             mp,
             feed={'value': value,
                   'loc': loc,
                   'scale': scale},
             fetch_list=[actual_log_prob, expected_log_prob, sample_data])

    def test_log_prob(self):
        np.testing.assert_allclose(
            self.actual_log_prob,
            self.expected_log_prob,
            rtol=config.RTOL.get(str(self.dtype)),
            atol=config.ATOL.get(str(self.dtype)))

    def transformed_log_prob(self, value, base, transforms):
        log_prob = 0.0
        y = value
        for t in reversed(transforms):
            x = t.inverse(y)
            log_prob = log_prob - t.forward_log_det_jacobian(x)
            y = x
        log_prob += base.log_prob(y)
        return log_prob

    # TODO(cxxly): Add Kolmogorov-Smirnov test for sample result.
    def test_sample(self):
        expected_shape = (5, 10, 8, 1)
        self.assertEqual(tuple(self.sample_data.shape), expected_shape)
        self.assertEqual(self.sample_data.dtype, self.dtype)


if __name__ == '__main__':
    unittest.main()
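Note the different expected sample shapes in the two files: the dygraph test builds its base distribution from scalar parameters, so sampling with shape [5, 10, 8] yields exactly (5, 10, 8), while this static test feeds loc and scale of shape [1], giving the base a batch shape of (1,) that is appended, hence (5, 10, 8, 1). A minimal dygraph sketch of the same rule, with shapes inferred from what the two tests assert:

import paddle

exp = [paddle.distribution.ExpTransform()]

# Scalar parameters: empty batch shape, sample shape is just [5, 10, 8].
d_scalar = paddle.distribution.TransformedDistribution(
    paddle.distribution.Normal(0., 1.), exp)
print(d_scalar.sample([5, 10, 8]).shape)  # [5, 10, 8]

# Shape-[1] parameters: batch shape (1,) is appended to the sample shape.
d_vector = paddle.distribution.TransformedDistribution(
    paddle.distribution.Normal(paddle.zeros([1]), paddle.ones([1])), exp)
print(d_vector.sample([5, 10, 8]).shape)  # [5, 10, 8, 1]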
@@ -343,3 +343,7 @@ class UniformTestSample2(UniformTestSample):
     def init_param(self):
         self.low = -5.0
         self.high = 2.0
+
+
+if __name__ == '__main__':
+    unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import paddle
from paddle.distribution import variable
from paddle.distribution import constraint

import config
import parameterize as param


@param.param_cls(
    (param.TEST_CASE_NAME, 'is_discrete', 'event_rank', 'constraint'),
    [('NotImplement', False, 0, constraint.Constraint())])
class TestVariable(unittest.TestCase):
    def setUp(self):
        self._var = variable.Variable(self.is_discrete, self.event_rank,
                                      self.constraint)

    @param.param_func([(1, )])
    def test_constraint(self, value):
        with self.assertRaises(NotImplementedError):
            self._var.constraint(value)


@param.param_cls((param.TEST_CASE_NAME, 'base', 'rank'),
                 [('real_base', variable.real, 10)])
class TestIndependent(unittest.TestCase):
    def setUp(self):
        self._var = variable.Independent(self.base, self.rank)

    @param.param_func([(paddle.rand([2, 3, 4]), ValueError), ])
    def test_constraint(self, value, expect):
        with self.assertRaises(expect):
            self._var.constraint(value)


@param.param_cls((param.TEST_CASE_NAME, 'vars', 'axis'),
                 [('real_base', [variable.real], 10)])
class TestStack(unittest.TestCase):
    def setUp(self):
        self._var = variable.Stack(self.vars, self.axis)

    def test_is_discrete(self):
        self.assertEqual(self._var.is_discrete, False)

    @param.param_func([(paddle.rand([2, 3, 4]), ValueError), ])
    def test_constraint(self, value, expect):
        with self.assertRaises(expect):
            self._var.constraint(value)


if __name__ == '__main__':
    unittest.main()
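The `variable.Independent` case above fails its constraint check because a reinterpreted rank of 10 exceeds the rank of the value being checked. At the distribution level, the `Independent` wrapper added in this PR performs the corresponding reinterpretation of trailing batch dimensions as event dimensions, summing `log_prob` over them. A short dygraph sketch, with the behavior inferred from the API added in this commit:

import paddle

# Base has batch_shape [3]; reinterpreting 1 batch dim as an event dim
# turns three independent normals into one 3-dimensional distribution.
base = paddle.distribution.Normal(paddle.zeros([3]), paddle.ones([3]))
ind = paddle.distribution.Independent(base, 1)

value = paddle.rand([3])
print(base.log_prob(value).shape)  # one log-prob per component
print(ind.log_prob(value).shape)   # summed over the reinterpreted dim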