提交 b366608a 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4952 Fix errors in log calculation logics

Merge pull request !4952 from peixu_ren/custom_pp_ops
......@@ -15,24 +15,30 @@
"""Utitly functions to help distribution class."""
import numpy as np
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
def log_by_step(input_x):
"""
Log op on Ascend is calculated as log(abs(x)).
Fix this with putting negative values as nan.
"""
select = P.Select()
log = P.Log()
less = P.Less()
lessequal = P.LessEqual()
fill = P.Fill()
cast = P.Cast()
dtype = P.DType()
shape = P.Shape()
select = P.Select()
input_x = cast(input_x, mstype.float32)
nan = fill(dtype(input_x), shape(input_x), np.nan)
inf = fill(dtype(input_x), shape(input_x), np.inf)
neg_x = less(input_x, 0.0)
nonpos_x = lessequal(input_x, 0.0)
log_x = log(input_x)
nan = fill(dtype(input_x), shape(input_x), np.nan)
result = select(nonpos_x, nan, log_x)
return result
result = select(nonpos_x, -inf, log_x)
return select(neg_x, nan, result)
def log1p_by_step(x):
"""
......
......@@ -157,51 +157,127 @@ def test_cross_entropy():
ans = net(probs_b, probs_a)
assert isinstance(ans, Tensor)
class BernoulliBasics(nn.Cell):
class BernoulliConstruct(nn.Cell):
"""
Bernoulli distribution: going through construct.
"""
def __init__(self):
super(BernoulliConstruct, self).__init__()
self.b = msd.Bernoulli(0.5, dtype=dtype.int32)
self.b1 = msd.Bernoulli(dtype=dtype.int32)
def construct(self, value, probs):
prob = self.b('prob', value)
prob1 = self.b('prob', value, probs)
prob2 = self.b1('prob', value, probs)
return prob + prob1 + prob2
def test_bernoulli_construct():
"""
Test probability function going through construct.
"""
net = BernoulliConstruct()
value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32)
probs = Tensor([0.5], dtype=dtype.float32)
ans = net(value, probs)
assert isinstance(ans, Tensor)
class BernoulliMean(nn.Cell):
"""
Test class: basic mean/sd/var/mode/entropy function.
"""
def __init__(self):
super(BernoulliBasics, self).__init__()
super(BernoulliMean, self).__init__()
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
def construct(self):
mean = self.b.mean()
return mean
def test_mean():
"""
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
"""
net = BernoulliMean()
ans = net()
assert isinstance(ans, Tensor)
class BernoulliSd(nn.Cell):
"""
Test class: basic mean/sd/var/mode/entropy function.
"""
def __init__(self):
super(BernoulliSd, self).__init__()
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
def construct(self):
sd = self.b.sd()
return sd
def test_sd():
"""
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
"""
net = BernoulliSd()
ans = net()
assert isinstance(ans, Tensor)
class BernoulliVar(nn.Cell):
"""
Test class: basic mean/sd/var/mode/entropy function.
"""
def __init__(self):
super(BernoulliVar, self).__init__()
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
def construct(self):
var = self.b.var()
return var
def test_var():
"""
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
"""
net = BernoulliVar()
ans = net()
assert isinstance(ans, Tensor)
class BernoulliMode(nn.Cell):
"""
Test class: basic mean/sd/var/mode/entropy function.
"""
def __init__(self):
super(BernoulliMode, self).__init__()
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
def construct(self):
mode = self.b.mode()
entropy = self.b.entropy()
return mean + sd + var + mode + entropy
return mode
def test_bascis():
def test_mode():
"""
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
"""
net = BernoulliBasics()
net = BernoulliMode()
ans = net()
assert isinstance(ans, Tensor)
class BernoulliConstruct(nn.Cell):
class BernoulliEntropy(nn.Cell):
"""
Bernoulli distribution: going through construct.
Test class: basic mean/sd/var/mode/entropy function.
"""
def __init__(self):
super(BernoulliConstruct, self).__init__()
self.b = msd.Bernoulli(0.5, dtype=dtype.int32)
self.b1 = msd.Bernoulli(dtype=dtype.int32)
super(BernoulliEntropy, self).__init__()
self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)
def construct(self, value, probs):
prob = self.b('prob', value)
prob1 = self.b('prob', value, probs)
prob2 = self.b1('prob', value, probs)
return prob + prob1 + prob2
def construct(self):
entropy = self.b.entropy()
return entropy
def test_bernoulli_construct():
def test_entropy():
"""
Test probability function going through construct.
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
"""
net = BernoulliConstruct()
value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32)
probs = Tensor([0.5], dtype=dtype.float32)
ans = net(value, probs)
net = BernoulliEntropy()
ans = net()
assert isinstance(ans, Tensor)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册