diff --git a/mindspore/nn/probability/bijector/__init__.py b/mindspore/nn/probability/bijector/__init__.py index 3108742aeae4a1edf4723d5691d55a111008e22f..8caf4e428f58243bab667fafa6823f748d9982a9 100644 --- a/mindspore/nn/probability/bijector/__init__.py +++ b/mindspore/nn/probability/bijector/__init__.py @@ -21,7 +21,13 @@ The high-level components(Bijectors) used to construct the probabilistic network from .bijector import Bijector from .power_transform import PowerTransform from .exp import Exp +from .scalar_affine import ScalarAffine +from .softplus import Softplus -__all__ = ['Bijector', - 'PowerTransform', - 'Exp'] +__all__ = [ + 'Bijector', + 'PowerTransform', + 'Exp', + 'ScalarAffine', + 'Softplus', +] diff --git a/mindspore/nn/probability/bijector/bijector.py b/mindspore/nn/probability/bijector/bijector.py index e6530b9f7026d7cc1acf8ad3b30ae5ba053de8e2..ac011fda33274dc3d9b0556e94bec6a88bf8f4f8 100644 --- a/mindspore/nn/probability/bijector/bijector.py +++ b/mindspore/nn/probability/bijector/bijector.py @@ -14,6 +14,7 @@ # ============================================================================ """Bijector""" from mindspore.nn.cell import Cell +from mindspore._checkparam import Validator as validator from ..distribution import Distribution from ..distribution import TransformedDistribution @@ -39,6 +40,9 @@ class Bijector(Cell): Constructor of bijector class. """ super(Bijector, self).__init__() + validator.check_value_type('name', name, [str], 'Bijector') + validator.check_value_type('is_constant_jacobian', is_constant_jacobian, [bool], name) + validator.check_value_type('is_injective', is_injective, [bool], name) self._name = name self._dtype = dtype self._parameters = {} diff --git a/mindspore/nn/probability/bijector/scalar_affine.py b/mindspore/nn/probability/bijector/scalar_affine.py new file mode 100644 index 0000000000000000000000000000000000000000..b48df1f0a7e0460f30657aa18b7295326d494332 --- /dev/null +++ b/mindspore/nn/probability/bijector/scalar_affine.py @@ -0,0 +1,116 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Scalar Affine Bijector""" +from mindspore.ops import operations as P +from mindspore._checkparam import Validator as validator +from ..distribution._utils.utils import cast_to_tensor +from .bijector import Bijector + +class ScalarAffine(Bijector): + """ + Scalar Affine Bijector. + This Bijector performs the operation: Y = a * X + b, where a is the scale + factor and b is the shift factor. + + Args: + scale (float): scale factor. Default: 1.0. + shift (float): shift factor. Default: 0.0. + + Examples: + >>> # To initialize a ScalarAffine bijector of scale 1 and shift 2 + >>> scalaraffine = nn.probability.bijector.ScalarAffine(1, 2) + >>> + >>> # To use ScalarAffine bijector in a network + >>> class net(Cell): + >>> def __init__(self): + >>> super(net, self).__init__(): + >>> self.s1 = nn.probability.bijector.ScalarAffine(1, 2) + >>> + >>> def construct(self, value): + >>> # Similar calls can be made to other probability functions + >>> # by replacing 'forward' with the name of the function + >>> ans = self.s1.forward(value) + >>> ans = self.s1.inverse(value) + >>> ans = self.s1.forward_log_jacobian(value) + >>> ans = self.s1.inverse_log_jacobian(value) + """ + def __init__(self, + scale=1.0, + shift=0.0, + name='ScalarAffine'): + """ + Constructor of scalar affine bijector. + """ + param = dict(locals()) + validator.check_value_type('scale', scale, [float], name) + validator.check_value_type('shift', shift, [float], name) + self._scale = cast_to_tensor(scale) + self._shift = cast_to_tensor(shift) + super(ScalarAffine, self).__init__( + is_constant_jacobian=True, + is_injective=True, + name=name, + dtype=None, + param=param) + + self.log = P.Log() + self.oneslike = P.OnesLike() + + @property + def scale(self): + return self._scale + + @property + def shift(self): + return self._shift + + def extend_repr(self): + str_info = f'scale = {self.scale}, shift = {self.shift}' + return str_info + + def shape_mapping(self, shape): + return shape + + def _forward(self, x): + r""" + .. math:: + f(x) = a * x + b + """ + return self.scale * x + self.shift + + def _inverse(self, y): + r""" + .. math:: + f(y) = \frac{y - b}{a} + """ + return (y - self.shift) / self.scale + + def _forward_log_jacobian(self, value): + r""" + .. math:: + f(x) = a * x + b + f'(x) = a + \log(f'(x)) = \log(a) + """ + return self.log(self.scale) * self.oneslike(value) + + def _inverse_log_jacobian(self, value): + r""" + .. math:: + f(y) = \frac{(y - b)}{a} + f'(x) = \frac{1.0}{a} + \log(f'(x)) = - \log(a) + """ + return -1. * self.log(self.scale) * self.oneslike(value) diff --git a/mindspore/nn/probability/bijector/softplus.py b/mindspore/nn/probability/bijector/softplus.py new file mode 100644 index 0000000000000000000000000000000000000000..26f70c8fc7e906d35ee634708a936162e40c915f --- /dev/null +++ b/mindspore/nn/probability/bijector/softplus.py @@ -0,0 +1,124 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Softplus Bijector""" +from mindspore.ops import operations as P +from mindspore.nn.layer.activation import LogSigmoid +from mindspore._checkparam import Validator as validator +from ..distribution._utils.utils import cast_to_tensor +from .bijector import Bijector + +class Softplus(Bijector): + r""" + Softplus Bijector. + This Bijector performs the operation: Y = \frac{\log(1 + e ^ {kX})}{k}, where k is the sharpness factor. + + Args: + sharpness (float): scale factor. Default: 1.0. + + Examples: + >>> # To initialize a Softplus bijector of sharpness 2 + >>> softplus = nn.probability.bijector.Softfplus(2) + >>> + >>> # To use ScalarAffine bijector in a network + >>> class net(Cell): + >>> def __init__(self): + >>> super(net, self).__init__(): + >>> self.sp1 = nn.probability.bijector.Softflus(2) + >>> + >>> def construct(self, value): + >>> # Similar calls can be made to other probability functions + >>> # by replacing 'forward' with the name of the function + >>> ans = self.sp1.forward(value) + >>> ans = self.sp1.inverse(value) + >>> ans = self.sp1.forward_log_jacobian(value) + >>> ans = self.sp1.inverse_log_jacobian(value) + """ + def __init__(self, + sharpness=1.0, + name='Softplus'): + param = dict(locals()) + validator.check_value_type('sharpness', sharpness, [float], name) + super(Softplus, self).__init__(name=name, param=param) + self._sharpness = cast_to_tensor(sharpness) + + self.exp = P.Exp() + self.expm1 = self._expm1_by_step + self.log_sigmoid = LogSigmoid() + self.log = P.Log() + self.sigmoid = P.Sigmoid() + + self.softplus = self._softplus + self.inverse_softplus = self._inverse_softplus + + def _expm1_by_step(self, x): + """ + Expm1 ops under GPU context. + """ + return self.exp(x) - 1.0 + + def _softplus(self, x): + return self.log(self.exp(x) + 1.0) + + def _inverse_softplus(self, x): + r""" + .. math:: + f(x) = \frac{\log(1 + e^{x}))} + f^{-1}(y) = \frac{\log(e^{y} - 1)} + """ + return self.log(self.expm1(x)) + + @property + def sharpness(self): + return self._sharpness + + def extend_repr(self): + str_info = f'sharpness = {self.sharpness}' + return str_info + + def shape_mapping(self, shape): + return shape + + def _forward(self, x): + scaled_value = self.sharpness * x + return self.softplus(scaled_value) / self.sharpness + + def _inverse(self, y): + r""" + .. math:: + f(x) = \frac{\log(1 + e^{kx}))}{k} + f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k} + """ + scaled_value = self.sharpness * y + return self.inverse_softplus(scaled_value) / self.sharpness + + def _forward_log_jacobian(self, x): + r""" + .. math: + f(x) = \log(1 + e^{kx}) / k + f'(x) = \frac{e^{kx}}{ 1 + e^{kx}} + \log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx) + """ + scaled_value = self.sharpness * x + return self.log_sigmoid(scaled_value) + + def _inverse_log_jacobian(self, y): + r""" + .. math: + f(y) = \frac{\log(e^{ky} - 1)}{k} + f'(y) = \frac{e^{ky}}{e^{ky} - 1} + \log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky) + """ + scaled_value = self.sharpness * y + return scaled_value - self.inverse_softplus(scaled_value) diff --git a/tests/st/ops/ascend/test_bijector/test_scalar_affine.py b/tests/st/ops/ascend/test_bijector/test_scalar_affine.py new file mode 100644 index 0000000000000000000000000000000000000000..137f5e0f056dc68ad1de4b08d88d4f5ecf505c2e --- /dev/null +++ b/tests/st/ops/ascend/test_bijector/test_scalar_affine.py @@ -0,0 +1,99 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test cases for scalar affine""" +import numpy as np +import mindspore.context as context +import mindspore.nn as nn +import mindspore.nn.probability.bijector as msb +from mindspore import Tensor +from mindspore import dtype + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class Net(nn.Cell): + """ + Test class: forward pass of bijector. + """ + def __init__(self): + super(Net, self).__init__() + self.bijector = msb.ScalarAffine(scale=2.0, shift=1.0) + + def construct(self, x_): + return self.bijector.forward(x_) + +def test_forward(): + forward = Net() + x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32) + ans = forward(Tensor(x, dtype=dtype.float32)) + tol = 1e-6 + expected = 2 * x + 1 + assert (np.abs(ans.asnumpy() - expected) < tol).all() + +class Net1(nn.Cell): + """ + Test class: backward pass of bijector. + """ + def __init__(self): + super(Net1, self).__init__() + self.bijector = msb.ScalarAffine(shift=1.0, scale=2.0) + + def construct(self, x_): + return self.bijector.inverse(x_) + +def test_backward(): + backward = Net1() + x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32) + ans = backward(Tensor(x, dtype=dtype.float32)) + tol = 1e-6 + expected = 0.5 * (x - 1.0) + assert (np.abs(ans.asnumpy() - expected) < tol).all() + +class Net2(nn.Cell): + """ + Test class: Forward Jacobian. + """ + def __init__(self): + super(Net2, self).__init__() + self.bijector = msb.ScalarAffine(shift=1.0, scale=2.0) + + def construct(self, x_): + return self.bijector.forward_log_jacobian(x_) + +def test_forward_jacobian(): + forward_jacobian = Net2() + x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) + ans = forward_jacobian(x) + expected = np.log([2.0, 2.0, 2.0, 2.0]) + tol = 1e-6 + assert (np.abs(ans.asnumpy() - expected) < tol).all() + +class Net3(nn.Cell): + """ + Test class: Backward Jacobian. + """ + def __init__(self): + super(Net3, self).__init__() + self.bijector = msb.ScalarAffine(shift=1.0, scale=2.0) + + def construct(self, x_): + return self.bijector.inverse_log_jacobian(x_) + +def test_backward_jacobian(): + backward_jacobian = Net3() + x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) + ans = backward_jacobian(x) + expected = np.log([0.5, 0.5, 0.5, 0.5]) + tol = 1e-6 + assert (np.abs(ans.asnumpy() - expected) < tol).all() diff --git a/tests/st/ops/ascend/test_bijector/test_softplus.py b/tests/st/ops/ascend/test_bijector/test_softplus.py new file mode 100644 index 0000000000000000000000000000000000000000..9bf33aa2549096acfa3baf2669141685137f3815 --- /dev/null +++ b/tests/st/ops/ascend/test_bijector/test_softplus.py @@ -0,0 +1,99 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test cases for scalar affine""" +import numpy as np +import mindspore.context as context +import mindspore.nn as nn +import mindspore.nn.probability.bijector as msb +from mindspore import Tensor +from mindspore import dtype + +context.set_context(device_target="Ascend") + +class Net(nn.Cell): + """ + Test class: forward pass of bijector. + """ + def __init__(self): + super(Net, self).__init__() + self.bijector = msb.Softplus(sharpness=2.0) + + def construct(self, x_): + return self.bijector.forward(x_) + +def test_forward(): + forward = Net() + x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32) + ans = forward(Tensor(x, dtype=dtype.float32)) + expected = np.log(1 + np.exp(2 * x)) * 0.5 + tol = 1e-6 + assert (np.abs(ans.asnumpy() - expected) < tol).all() + +class Net1(nn.Cell): + """ + Test class: backward pass of bijector. + """ + def __init__(self): + super(Net1, self).__init__() + self.bijector = msb.Softplus(sharpness=2.0) + + def construct(self, x_): + return self.bijector.inverse(x_) + +def test_backward(): + backward = Net1() + x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32) + ans = backward(Tensor(x, dtype=dtype.float32)) + expected = np.log(np.exp(2 * x) - 1) * 0.5 + tol = 1e-6 + assert (np.abs(ans.asnumpy() - expected) < tol).all() + +class Net2(nn.Cell): + """ + Test class: Forward Jacobian. + """ + def __init__(self): + super(Net2, self).__init__() + self.bijector = msb.Softplus(sharpness=2.0) + + def construct(self, x_): + return self.bijector.forward_log_jacobian(x_) + +def test_forward_jacobian(): + forward_jacobian = Net2() + x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32) + ans = forward_jacobian(Tensor(x, dtype=dtype.float32)) + expected = np.log(np.exp(2 * x) / (1 + np.exp(2.0 * x))) + tol = 1e-6 + assert (np.abs(ans.asnumpy() - expected) < tol).all() + +class Net3(nn.Cell): + """ + Test class: Backward Jacobian. + """ + def __init__(self): + super(Net3, self).__init__() + self.bijector = msb.Softplus(sharpness=2.0) + + def construct(self, x_): + return self.bijector.inverse_log_jacobian(x_) + +def test_backward_jacobian(): + backward_jacobian = Net3() + x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32) + ans = backward_jacobian(Tensor(x, dtype=dtype.float32)) + expected = np.log(np.exp(2.0 * x) / np.expm1(2.0 * x)) + tol = 1e-6 + assert (np.abs(ans.asnumpy() - expected) < tol).all() diff --git a/tests/ut/python/nn/bijector/test_exp.py b/tests/ut/python/nn/bijector/test_exp.py index 13e3e09a34c574933158cf5d2ca50c32305bc5b0..98f6315a5fb1ecf36cb3527d9b9f8f50668580c0 100644 --- a/tests/ut/python/nn/bijector/test_exp.py +++ b/tests/ut/python/nn/bijector/test_exp.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ """test cases for exp""" +import pytest import mindspore.nn as nn import mindspore.nn.probability.bijector as msb from mindspore import Tensor @@ -21,8 +22,10 @@ from mindspore import dtype def test_init(): b = msb.Exp() assert isinstance(b, msb.Bijector) - b = msb.Exp(1.0) - assert isinstance(b, msb.Bijector) + +def test_type(): + with pytest.raises(TypeError): + msb.Exp(name=0.1) class Net(nn.Cell): """ diff --git a/tests/ut/python/nn/bijector/test_power_transform.py b/tests/ut/python/nn/bijector/test_power_transform.py index 50ea5dbd44c3b0180fb37de61062f4e5a6e5ddda..e5ed5cd2713fafdc47273791a4d6d9e4b3a99d4b 100644 --- a/tests/ut/python/nn/bijector/test_power_transform.py +++ b/tests/ut/python/nn/bijector/test_power_transform.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ """test cases for powertransform""" +import pytest import mindspore.nn as nn import mindspore.nn.probability.bijector as msb from mindspore import Tensor @@ -24,6 +25,12 @@ def test_init(): b = msb.PowerTransform(1) assert isinstance(b, msb.Bijector) +def test_type(): + with pytest.raises(TypeError): + msb.PowerTransform(power='power') + with pytest.raises(TypeError): + msb.PowerTransform(name=0.1) + class Net(nn.Cell): """ Test class: forward and inverse pass of bijector. diff --git a/tests/ut/python/nn/bijector/test_scalar_affine.py b/tests/ut/python/nn/bijector/test_scalar_affine.py new file mode 100644 index 0000000000000000000000000000000000000000..eab0946c7e2e68f73001c9a6bea21ef2305bb2af --- /dev/null +++ b/tests/ut/python/nn/bijector/test_scalar_affine.py @@ -0,0 +1,139 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test cases for scalar affine""" +import pytest +import mindspore.nn as nn +import mindspore.nn.probability.bijector as msb +from mindspore import Tensor +from mindspore import dtype + +def test_init(): + """ + Test initializations. + """ + b = msb.ScalarAffine() + assert isinstance(b, msb.Bijector) + b = msb.ScalarAffine(scale=1.0) + assert isinstance(b, msb.Bijector) + b = msb.ScalarAffine(shift=2.0) + assert isinstance(b, msb.Bijector) + b = msb.ScalarAffine(3.0, 4.0) + assert isinstance(b, msb.Bijector) + +def test_type(): + with pytest.raises(TypeError): + msb.ScalarAffine(scale='scale') + with pytest.raises(TypeError): + msb.ScalarAffine(shift='shift') + with pytest.raises(TypeError): + msb.ScalarAffine(name=0.1) + +class ForwardBackward(nn.Cell): + """ + Test class: forward and backward pass. + """ + def __init__(self): + super(ForwardBackward, self).__init__() + self.b1 = msb.ScalarAffine(2.0, 1.0) + self.b2 = msb.ScalarAffine() + + def construct(self, x_): + ans1 = self.b1.inverse(self.b1.forward(x_)) + ans2 = self.b2.inverse(self.b2.forward(x_)) + return ans1 + ans2 + +def test_forward_and_backward_pass(): + """ + Test forward and backward pass of ScalarAffine bijector. + """ + net = ForwardBackward() + x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) + ans = net(x) + assert isinstance(ans, Tensor) + +class ForwardJacobian(nn.Cell): + """ + Test class: Forward log Jacobian. + """ + def __init__(self): + super(ForwardJacobian, self).__init__() + self.b1 = msb.ScalarAffine(2.0, 1.0) + self.b2 = msb.ScalarAffine() + + def construct(self, x_): + ans1 = self.b1.forward_log_jacobian(x_) + ans2 = self.b2.forward_log_jacobian(x_) + return ans1 + ans2 + +def test_forward_jacobian(): + """ + Test forward log jacobian of ScalarAffine bijector. + """ + net = ForwardJacobian() + x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) + ans = net(x) + assert isinstance(ans, Tensor) + + +class BackwardJacobian(nn.Cell): + """ + Test class: Backward log Jacobian. + """ + def __init__(self): + super(BackwardJacobian, self).__init__() + self.b1 = msb.ScalarAffine(2.0, 1.0) + self.b2 = msb.ScalarAffine() + + def construct(self, x_): + ans1 = self.b1.inverse_log_jacobian(x_) + ans2 = self.b2.inverse_log_jacobian(x_) + return ans1 + ans2 + +def test_backward_jacobian(): + """ + Test backward log jacobian of ScalarAffine bijector. + """ + net = BackwardJacobian() + x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) + ans = net(x) + assert isinstance(ans, Tensor) + + +class Net(nn.Cell): + """ + Test class: function calls going through construct. + """ + def __init__(self): + super(Net, self).__init__() + self.b1 = msb.ScalarAffine(1.0, 0.0) + self.b2 = msb.ScalarAffine() + + def construct(self, x_): + ans1 = self.b1('inverse', self.b1('forward', x_)) + ans2 = self.b2('inverse', self.b2('forward', x_)) + ans3 = self.b1('forward_log_jacobian', x_) + ans4 = self.b2('forward_log_jacobian', x_) + ans5 = self.b1('inverse_log_jacobian', x_) + ans6 = self.b2('inverse_log_jacobian', x_) + return ans1 - ans2 + ans3 -ans4 + ans5 - ans6 + +def test_old_api(): + """ + Test old api which goes through construct. + """ + net = Net() + x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) + ans = net(x) + assert isinstance(ans, Tensor) diff --git a/tests/ut/python/nn/bijector/test_softplus.py b/tests/ut/python/nn/bijector/test_softplus.py new file mode 100644 index 0000000000000000000000000000000000000000..4255751a78839b3ad4732f3dfabc762f73df0be0 --- /dev/null +++ b/tests/ut/python/nn/bijector/test_softplus.py @@ -0,0 +1,133 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test cases for scalar affine""" +import pytest +import mindspore.nn as nn +import mindspore.nn.probability.bijector as msb +from mindspore import Tensor +from mindspore import dtype + +def test_init(): + """ + Test initializations. + """ + b = msb.Softplus() + assert isinstance(b, msb.Bijector) + b = msb.Softplus(1.0) + assert isinstance(b, msb.Bijector) + +def test_type(): + with pytest.raises(TypeError): + msb.Softplus(sharpness='sharpness') + with pytest.raises(TypeError): + msb.Softplus(name=0.1) + +class ForwardBackward(nn.Cell): + """ + Test class: forward and backward pass. + """ + def __init__(self): + super(ForwardBackward, self).__init__() + self.b1 = msb.Softplus(2.0) + self.b2 = msb.Softplus() + + def construct(self, x_): + ans1 = self.b1.inverse(self.b1.forward(x_)) + ans2 = self.b2.inverse(self.b2.forward(x_)) + return ans1 + ans2 + +def test_forward_and_backward_pass(): + """ + Test forward and backward pass of Softplus bijector. + """ + net = ForwardBackward() + x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) + ans = net(x) + assert isinstance(ans, Tensor) + +class ForwardJacobian(nn.Cell): + """ + Test class: Forward log Jacobian. + """ + def __init__(self): + super(ForwardJacobian, self).__init__() + self.b1 = msb.Softplus(2.0) + self.b2 = msb.Softplus() + + def construct(self, x_): + ans1 = self.b1.forward_log_jacobian(x_) + ans2 = self.b2.forward_log_jacobian(x_) + return ans1 + ans2 + +def test_forward_jacobian(): + """ + Test forward log jacobian of Softplus bijector. + """ + net = ForwardJacobian() + x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) + ans = net(x) + assert isinstance(ans, Tensor) + + +class BackwardJacobian(nn.Cell): + """ + Test class: Backward log Jacobian. + """ + def __init__(self): + super(BackwardJacobian, self).__init__() + self.b1 = msb.Softplus(2.0) + self.b2 = msb.Softplus() + + def construct(self, x_): + ans1 = self.b1.inverse_log_jacobian(x_) + ans2 = self.b2.inverse_log_jacobian(x_) + return ans1 + ans2 + +def test_backward_jacobian(): + """ + Test backward log jacobian of Softplus bijector. + """ + net = BackwardJacobian() + x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) + ans = net(x) + assert isinstance(ans, Tensor) + + +class Net(nn.Cell): + """ + Test class: function calls going through construct. + """ + def __init__(self): + super(Net, self).__init__() + self.b1 = msb.Softplus(1.0) + self.b2 = msb.Softplus() + + def construct(self, x_): + ans1 = self.b1('inverse', self.b1('forward', x_)) + ans2 = self.b2('inverse', self.b2('forward', x_)) + ans3 = self.b1('forward_log_jacobian', x_) + ans4 = self.b2('forward_log_jacobian', x_) + ans5 = self.b1('inverse_log_jacobian', x_) + ans6 = self.b2('inverse_log_jacobian', x_) + return ans1 - ans2 + ans3 -ans4 + ans5 - ans6 + +def test_old_api(): + """ + Test old api which goes through construct. + """ + net = Net() + x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32) + ans = net(x) + assert isinstance(ans, Tensor)