提交 6772bbde 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4021 Add ScalarAffine and Softplus bijector

Merge pull request !4021 from XunDeng/scalar_affine_softplus
...@@ -21,7 +21,13 @@ The high-level components(Bijectors) used to construct the probabilistic network ...@@ -21,7 +21,13 @@ The high-level components(Bijectors) used to construct the probabilistic network
from .bijector import Bijector from .bijector import Bijector
from .power_transform import PowerTransform from .power_transform import PowerTransform
from .exp import Exp from .exp import Exp
from .scalar_affine import ScalarAffine
from .softplus import Softplus
__all__ = ['Bijector', __all__ = [
'PowerTransform', 'Bijector',
'Exp'] 'PowerTransform',
'Exp',
'ScalarAffine',
'Softplus',
]
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# ============================================================================ # ============================================================================
"""Bijector""" """Bijector"""
from mindspore.nn.cell import Cell from mindspore.nn.cell import Cell
from mindspore._checkparam import Validator as validator
from ..distribution import Distribution from ..distribution import Distribution
from ..distribution import TransformedDistribution from ..distribution import TransformedDistribution
...@@ -39,6 +40,9 @@ class Bijector(Cell): ...@@ -39,6 +40,9 @@ class Bijector(Cell):
Constructor of bijector class. Constructor of bijector class.
""" """
super(Bijector, self).__init__() super(Bijector, self).__init__()
validator.check_value_type('name', name, [str], 'Bijector')
validator.check_value_type('is_constant_jacobian', is_constant_jacobian, [bool], name)
validator.check_value_type('is_injective', is_injective, [bool], name)
self._name = name self._name = name
self._dtype = dtype self._dtype = dtype
self._parameters = {} self._parameters = {}
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Scalar Affine Bijector"""
from mindspore.ops import operations as P
from mindspore._checkparam import Validator as validator
from ..distribution._utils.utils import cast_to_tensor
from .bijector import Bijector
class ScalarAffine(Bijector):
"""
Scalar Affine Bijector.
This Bijector performs the operation: Y = a * X + b, where a is the scale
factor and b is the shift factor.
Args:
scale (float): scale factor. Default: 1.0.
shift (float): shift factor. Default: 0.0.
Examples:
>>> # To initialize a ScalarAffine bijector of scale 1 and shift 2
>>> scalaraffine = nn.probability.bijector.ScalarAffine(1, 2)
>>>
>>> # To use ScalarAffine bijector in a network
>>> class net(Cell):
>>> def __init__(self):
>>> super(net, self).__init__():
>>> self.s1 = nn.probability.bijector.ScalarAffine(1, 2)
>>>
>>> def construct(self, value):
>>> # Similar calls can be made to other probability functions
>>> # by replacing 'forward' with the name of the function
>>> ans = self.s1.forward(value)
>>> ans = self.s1.inverse(value)
>>> ans = self.s1.forward_log_jacobian(value)
>>> ans = self.s1.inverse_log_jacobian(value)
"""
def __init__(self,
scale=1.0,
shift=0.0,
name='ScalarAffine'):
"""
Constructor of scalar affine bijector.
"""
param = dict(locals())
validator.check_value_type('scale', scale, [float], name)
validator.check_value_type('shift', shift, [float], name)
self._scale = cast_to_tensor(scale)
self._shift = cast_to_tensor(shift)
super(ScalarAffine, self).__init__(
is_constant_jacobian=True,
is_injective=True,
name=name,
dtype=None,
param=param)
self.log = P.Log()
self.oneslike = P.OnesLike()
@property
def scale(self):
return self._scale
@property
def shift(self):
return self._shift
def extend_repr(self):
str_info = f'scale = {self.scale}, shift = {self.shift}'
return str_info
def shape_mapping(self, shape):
return shape
def _forward(self, x):
r"""
.. math::
f(x) = a * x + b
"""
return self.scale * x + self.shift
def _inverse(self, y):
r"""
.. math::
f(y) = \frac{y - b}{a}
"""
return (y - self.shift) / self.scale
def _forward_log_jacobian(self, value):
r"""
.. math::
f(x) = a * x + b
f'(x) = a
\log(f'(x)) = \log(a)
"""
return self.log(self.scale) * self.oneslike(value)
def _inverse_log_jacobian(self, value):
r"""
.. math::
f(y) = \frac{(y - b)}{a}
f'(x) = \frac{1.0}{a}
\log(f'(x)) = - \log(a)
"""
return -1. * self.log(self.scale) * self.oneslike(value)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Softplus Bijector"""
from mindspore.ops import operations as P
from mindspore.nn.layer.activation import LogSigmoid
from mindspore._checkparam import Validator as validator
from ..distribution._utils.utils import cast_to_tensor
from .bijector import Bijector
class Softplus(Bijector):
r"""
Softplus Bijector.
This Bijector performs the operation: Y = \frac{\log(1 + e ^ {kX})}{k}, where k is the sharpness factor.
Args:
sharpness (float): scale factor. Default: 1.0.
Examples:
>>> # To initialize a Softplus bijector of sharpness 2
>>> softplus = nn.probability.bijector.Softfplus(2)
>>>
>>> # To use ScalarAffine bijector in a network
>>> class net(Cell):
>>> def __init__(self):
>>> super(net, self).__init__():
>>> self.sp1 = nn.probability.bijector.Softflus(2)
>>>
>>> def construct(self, value):
>>> # Similar calls can be made to other probability functions
>>> # by replacing 'forward' with the name of the function
>>> ans = self.sp1.forward(value)
>>> ans = self.sp1.inverse(value)
>>> ans = self.sp1.forward_log_jacobian(value)
>>> ans = self.sp1.inverse_log_jacobian(value)
"""
def __init__(self,
sharpness=1.0,
name='Softplus'):
param = dict(locals())
validator.check_value_type('sharpness', sharpness, [float], name)
super(Softplus, self).__init__(name=name, param=param)
self._sharpness = cast_to_tensor(sharpness)
self.exp = P.Exp()
self.expm1 = self._expm1_by_step
self.log_sigmoid = LogSigmoid()
self.log = P.Log()
self.sigmoid = P.Sigmoid()
self.softplus = self._softplus
self.inverse_softplus = self._inverse_softplus
def _expm1_by_step(self, x):
"""
Expm1 ops under GPU context.
"""
return self.exp(x) - 1.0
def _softplus(self, x):
return self.log(self.exp(x) + 1.0)
def _inverse_softplus(self, x):
r"""
.. math::
f(x) = \frac{\log(1 + e^{x}))}
f^{-1}(y) = \frac{\log(e^{y} - 1)}
"""
return self.log(self.expm1(x))
@property
def sharpness(self):
return self._sharpness
def extend_repr(self):
str_info = f'sharpness = {self.sharpness}'
return str_info
def shape_mapping(self, shape):
return shape
def _forward(self, x):
scaled_value = self.sharpness * x
return self.softplus(scaled_value) / self.sharpness
def _inverse(self, y):
r"""
.. math::
f(x) = \frac{\log(1 + e^{kx}))}{k}
f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k}
"""
scaled_value = self.sharpness * y
return self.inverse_softplus(scaled_value) / self.sharpness
def _forward_log_jacobian(self, x):
r"""
.. math:
f(x) = \log(1 + e^{kx}) / k
f'(x) = \frac{e^{kx}}{ 1 + e^{kx}}
\log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx)
"""
scaled_value = self.sharpness * x
return self.log_sigmoid(scaled_value)
def _inverse_log_jacobian(self, y):
r"""
.. math:
f(y) = \frac{\log(e^{ky} - 1)}{k}
f'(y) = \frac{e^{ky}}{e^{ky} - 1}
\log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky)
"""
scaled_value = self.sharpness * y
return scaled_value - self.inverse_softplus(scaled_value)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for scalar affine"""
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.bijector as msb
from mindspore import Tensor
from mindspore import dtype
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell):
"""
Test class: forward pass of bijector.
"""
def __init__(self):
super(Net, self).__init__()
self.bijector = msb.ScalarAffine(scale=2.0, shift=1.0)
def construct(self, x_):
return self.bijector.forward(x_)
def test_forward():
forward = Net()
x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
ans = forward(Tensor(x, dtype=dtype.float32))
tol = 1e-6
expected = 2 * x + 1
assert (np.abs(ans.asnumpy() - expected) < tol).all()
class Net1(nn.Cell):
"""
Test class: backward pass of bijector.
"""
def __init__(self):
super(Net1, self).__init__()
self.bijector = msb.ScalarAffine(shift=1.0, scale=2.0)
def construct(self, x_):
return self.bijector.inverse(x_)
def test_backward():
backward = Net1()
x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
ans = backward(Tensor(x, dtype=dtype.float32))
tol = 1e-6
expected = 0.5 * (x - 1.0)
assert (np.abs(ans.asnumpy() - expected) < tol).all()
class Net2(nn.Cell):
"""
Test class: Forward Jacobian.
"""
def __init__(self):
super(Net2, self).__init__()
self.bijector = msb.ScalarAffine(shift=1.0, scale=2.0)
def construct(self, x_):
return self.bijector.forward_log_jacobian(x_)
def test_forward_jacobian():
forward_jacobian = Net2()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans = forward_jacobian(x)
expected = np.log([2.0, 2.0, 2.0, 2.0])
tol = 1e-6
assert (np.abs(ans.asnumpy() - expected) < tol).all()
class Net3(nn.Cell):
"""
Test class: Backward Jacobian.
"""
def __init__(self):
super(Net3, self).__init__()
self.bijector = msb.ScalarAffine(shift=1.0, scale=2.0)
def construct(self, x_):
return self.bijector.inverse_log_jacobian(x_)
def test_backward_jacobian():
backward_jacobian = Net3()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans = backward_jacobian(x)
expected = np.log([0.5, 0.5, 0.5, 0.5])
tol = 1e-6
assert (np.abs(ans.asnumpy() - expected) < tol).all()
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for scalar affine"""
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.bijector as msb
from mindspore import Tensor
from mindspore import dtype
context.set_context(device_target="Ascend")
class Net(nn.Cell):
"""
Test class: forward pass of bijector.
"""
def __init__(self):
super(Net, self).__init__()
self.bijector = msb.Softplus(sharpness=2.0)
def construct(self, x_):
return self.bijector.forward(x_)
def test_forward():
forward = Net()
x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
ans = forward(Tensor(x, dtype=dtype.float32))
expected = np.log(1 + np.exp(2 * x)) * 0.5
tol = 1e-6
assert (np.abs(ans.asnumpy() - expected) < tol).all()
class Net1(nn.Cell):
"""
Test class: backward pass of bijector.
"""
def __init__(self):
super(Net1, self).__init__()
self.bijector = msb.Softplus(sharpness=2.0)
def construct(self, x_):
return self.bijector.inverse(x_)
def test_backward():
backward = Net1()
x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
ans = backward(Tensor(x, dtype=dtype.float32))
expected = np.log(np.exp(2 * x) - 1) * 0.5
tol = 1e-6
assert (np.abs(ans.asnumpy() - expected) < tol).all()
class Net2(nn.Cell):
"""
Test class: Forward Jacobian.
"""
def __init__(self):
super(Net2, self).__init__()
self.bijector = msb.Softplus(sharpness=2.0)
def construct(self, x_):
return self.bijector.forward_log_jacobian(x_)
def test_forward_jacobian():
forward_jacobian = Net2()
x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
ans = forward_jacobian(Tensor(x, dtype=dtype.float32))
expected = np.log(np.exp(2 * x) / (1 + np.exp(2.0 * x)))
tol = 1e-6
assert (np.abs(ans.asnumpy() - expected) < tol).all()
class Net3(nn.Cell):
"""
Test class: Backward Jacobian.
"""
def __init__(self):
super(Net3, self).__init__()
self.bijector = msb.Softplus(sharpness=2.0)
def construct(self, x_):
return self.bijector.inverse_log_jacobian(x_)
def test_backward_jacobian():
backward_jacobian = Net3()
x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
ans = backward_jacobian(Tensor(x, dtype=dtype.float32))
expected = np.log(np.exp(2.0 * x) / np.expm1(2.0 * x))
tol = 1e-6
assert (np.abs(ans.asnumpy() - expected) < tol).all()
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""test cases for exp""" """test cases for exp"""
import pytest
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.nn.probability.bijector as msb import mindspore.nn.probability.bijector as msb
from mindspore import Tensor from mindspore import Tensor
...@@ -21,8 +22,10 @@ from mindspore import dtype ...@@ -21,8 +22,10 @@ from mindspore import dtype
def test_init(): def test_init():
b = msb.Exp() b = msb.Exp()
assert isinstance(b, msb.Bijector) assert isinstance(b, msb.Bijector)
b = msb.Exp(1.0)
assert isinstance(b, msb.Bijector) def test_type():
with pytest.raises(TypeError):
msb.Exp(name=0.1)
class Net(nn.Cell): class Net(nn.Cell):
""" """
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""test cases for powertransform""" """test cases for powertransform"""
import pytest
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.nn.probability.bijector as msb import mindspore.nn.probability.bijector as msb
from mindspore import Tensor from mindspore import Tensor
...@@ -24,6 +25,12 @@ def test_init(): ...@@ -24,6 +25,12 @@ def test_init():
b = msb.PowerTransform(1) b = msb.PowerTransform(1)
assert isinstance(b, msb.Bijector) assert isinstance(b, msb.Bijector)
def test_type():
with pytest.raises(TypeError):
msb.PowerTransform(power='power')
with pytest.raises(TypeError):
msb.PowerTransform(name=0.1)
class Net(nn.Cell): class Net(nn.Cell):
""" """
Test class: forward and inverse pass of bijector. Test class: forward and inverse pass of bijector.
......
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for scalar affine"""
import pytest
import mindspore.nn as nn
import mindspore.nn.probability.bijector as msb
from mindspore import Tensor
from mindspore import dtype
def test_init():
"""
Test initializations.
"""
b = msb.ScalarAffine()
assert isinstance(b, msb.Bijector)
b = msb.ScalarAffine(scale=1.0)
assert isinstance(b, msb.Bijector)
b = msb.ScalarAffine(shift=2.0)
assert isinstance(b, msb.Bijector)
b = msb.ScalarAffine(3.0, 4.0)
assert isinstance(b, msb.Bijector)
def test_type():
with pytest.raises(TypeError):
msb.ScalarAffine(scale='scale')
with pytest.raises(TypeError):
msb.ScalarAffine(shift='shift')
with pytest.raises(TypeError):
msb.ScalarAffine(name=0.1)
class ForwardBackward(nn.Cell):
"""
Test class: forward and backward pass.
"""
def __init__(self):
super(ForwardBackward, self).__init__()
self.b1 = msb.ScalarAffine(2.0, 1.0)
self.b2 = msb.ScalarAffine()
def construct(self, x_):
ans1 = self.b1.inverse(self.b1.forward(x_))
ans2 = self.b2.inverse(self.b2.forward(x_))
return ans1 + ans2
def test_forward_and_backward_pass():
"""
Test forward and backward pass of ScalarAffine bijector.
"""
net = ForwardBackward()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans = net(x)
assert isinstance(ans, Tensor)
class ForwardJacobian(nn.Cell):
"""
Test class: Forward log Jacobian.
"""
def __init__(self):
super(ForwardJacobian, self).__init__()
self.b1 = msb.ScalarAffine(2.0, 1.0)
self.b2 = msb.ScalarAffine()
def construct(self, x_):
ans1 = self.b1.forward_log_jacobian(x_)
ans2 = self.b2.forward_log_jacobian(x_)
return ans1 + ans2
def test_forward_jacobian():
"""
Test forward log jacobian of ScalarAffine bijector.
"""
net = ForwardJacobian()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans = net(x)
assert isinstance(ans, Tensor)
class BackwardJacobian(nn.Cell):
"""
Test class: Backward log Jacobian.
"""
def __init__(self):
super(BackwardJacobian, self).__init__()
self.b1 = msb.ScalarAffine(2.0, 1.0)
self.b2 = msb.ScalarAffine()
def construct(self, x_):
ans1 = self.b1.inverse_log_jacobian(x_)
ans2 = self.b2.inverse_log_jacobian(x_)
return ans1 + ans2
def test_backward_jacobian():
"""
Test backward log jacobian of ScalarAffine bijector.
"""
net = BackwardJacobian()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans = net(x)
assert isinstance(ans, Tensor)
class Net(nn.Cell):
"""
Test class: function calls going through construct.
"""
def __init__(self):
super(Net, self).__init__()
self.b1 = msb.ScalarAffine(1.0, 0.0)
self.b2 = msb.ScalarAffine()
def construct(self, x_):
ans1 = self.b1('inverse', self.b1('forward', x_))
ans2 = self.b2('inverse', self.b2('forward', x_))
ans3 = self.b1('forward_log_jacobian', x_)
ans4 = self.b2('forward_log_jacobian', x_)
ans5 = self.b1('inverse_log_jacobian', x_)
ans6 = self.b2('inverse_log_jacobian', x_)
return ans1 - ans2 + ans3 -ans4 + ans5 - ans6
def test_old_api():
"""
Test old api which goes through construct.
"""
net = Net()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans = net(x)
assert isinstance(ans, Tensor)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for scalar affine"""
import pytest
import mindspore.nn as nn
import mindspore.nn.probability.bijector as msb
from mindspore import Tensor
from mindspore import dtype
def test_init():
"""
Test initializations.
"""
b = msb.Softplus()
assert isinstance(b, msb.Bijector)
b = msb.Softplus(1.0)
assert isinstance(b, msb.Bijector)
def test_type():
with pytest.raises(TypeError):
msb.Softplus(sharpness='sharpness')
with pytest.raises(TypeError):
msb.Softplus(name=0.1)
class ForwardBackward(nn.Cell):
"""
Test class: forward and backward pass.
"""
def __init__(self):
super(ForwardBackward, self).__init__()
self.b1 = msb.Softplus(2.0)
self.b2 = msb.Softplus()
def construct(self, x_):
ans1 = self.b1.inverse(self.b1.forward(x_))
ans2 = self.b2.inverse(self.b2.forward(x_))
return ans1 + ans2
def test_forward_and_backward_pass():
"""
Test forward and backward pass of Softplus bijector.
"""
net = ForwardBackward()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans = net(x)
assert isinstance(ans, Tensor)
class ForwardJacobian(nn.Cell):
"""
Test class: Forward log Jacobian.
"""
def __init__(self):
super(ForwardJacobian, self).__init__()
self.b1 = msb.Softplus(2.0)
self.b2 = msb.Softplus()
def construct(self, x_):
ans1 = self.b1.forward_log_jacobian(x_)
ans2 = self.b2.forward_log_jacobian(x_)
return ans1 + ans2
def test_forward_jacobian():
"""
Test forward log jacobian of Softplus bijector.
"""
net = ForwardJacobian()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans = net(x)
assert isinstance(ans, Tensor)
class BackwardJacobian(nn.Cell):
"""
Test class: Backward log Jacobian.
"""
def __init__(self):
super(BackwardJacobian, self).__init__()
self.b1 = msb.Softplus(2.0)
self.b2 = msb.Softplus()
def construct(self, x_):
ans1 = self.b1.inverse_log_jacobian(x_)
ans2 = self.b2.inverse_log_jacobian(x_)
return ans1 + ans2
def test_backward_jacobian():
"""
Test backward log jacobian of Softplus bijector.
"""
net = BackwardJacobian()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans = net(x)
assert isinstance(ans, Tensor)
class Net(nn.Cell):
"""
Test class: function calls going through construct.
"""
def __init__(self):
super(Net, self).__init__()
self.b1 = msb.Softplus(1.0)
self.b2 = msb.Softplus()
def construct(self, x_):
ans1 = self.b1('inverse', self.b1('forward', x_))
ans2 = self.b2('inverse', self.b2('forward', x_))
ans3 = self.b1('forward_log_jacobian', x_)
ans4 = self.b2('forward_log_jacobian', x_)
ans5 = self.b1('inverse_log_jacobian', x_)
ans6 = self.b2('inverse_log_jacobian', x_)
return ans1 - ans2 + ans3 -ans4 + ans5 - ans6
def test_old_api():
"""
Test old api which goes through construct.
"""
net = Net()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans = net(x)
assert isinstance(ans, Tensor)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册