diff --git a/mindspore/nn/probability/bijector/scalar_affine.py b/mindspore/nn/probability/bijector/scalar_affine.py index b75298be20faf6ecc7899c5e673969988b6a1aeb..e765187427d9b66dc744e02dc19a5d7825f83f3d 100644 --- a/mindspore/nn/probability/bijector/scalar_affine.py +++ b/mindspore/nn/probability/bijector/scalar_affine.py @@ -69,6 +69,7 @@ class ScalarAffine(Bijector): param=param) self.abs = P.Abs() + self.oneslike = P.OnesLike() self.log = log_generic @property @@ -92,7 +93,7 @@ class ScalarAffine(Bijector): f(x) = a * x + b """ x = self._check_value(x, 'value') - return self.scale * x + self.shift + return self.scale * x + self.shift * self.oneslike(x) def _inverse(self, y): r"""