Commit 064d370a authored by buxue

Fix bugs in Acosh bprop and Elu bprop

Parent ef596f26
@@ -128,7 +128,7 @@ class Validator:
     @staticmethod
     def check_number(arg_name, arg_value, value, rel, prim_name):
-        """Integer value judgment."""
+        """Number value judgment."""
         rel_fn = Rel.get_fns(rel)
         if not rel_fn(arg_value, value):
             rel_str = Rel.get_strs(rel).format(value)
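The docstring change above matches what the body of `check_number` actually does: it compares any numeric argument against a reference value using a relation looked up from `Rel`, not just integers. The snippet below is a minimal, self-contained sketch of that pattern with a simplified `Rel` table and a plain `ValueError`; it mirrors the call signature shown in the diff but is not the real MindSpore implementation.

```python
# Minimal sketch of a relation-based number check, following the
# (arg_name, arg_value, value, rel, prim_name) calling convention from the diff.
class Rel:
    EQ = "EQ"
    GT = "GT"
    _fns = {"EQ": lambda x, y: x == y, "GT": lambda x, y: x > y}
    _strs = {"EQ": "equal to {}", "GT": "greater than {}"}

    @staticmethod
    def get_fns(rel):
        return Rel._fns[rel]

    @staticmethod
    def get_strs(rel):
        return Rel._strs[rel]


def check_number(arg_name, arg_value, value, rel, prim_name):
    """Number value judgment: raise unless arg_value satisfies `rel` against value."""
    rel_fn = Rel.get_fns(rel)
    if not rel_fn(arg_value, value):
        rel_str = Rel.get_strs(rel).format(value)
        raise ValueError(f"For '{prim_name}', the '{arg_name}' should be {rel_str}, "
                         f"but got {arg_value}.")
    return arg_value


check_number("alpha", 1.0, 1.0, Rel.EQ, "Elu")    # passes and returns 1.0
# check_number("alpha", 0.9, 1.0, Rel.EQ, "Elu")  # would raise ValueError
```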
@@ -727,7 +727,7 @@ def get_bprop_acosh(self):
     input_grad = G.AcoshGrad()
     def bprop(x, out, dout):
-        dx = input_grad(x, dout)
+        dx = input_grad(out, dout)
         return (dx,)
     return bprop
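The Acosh fix replaces `x` with `out` in the call to `G.AcoshGrad`, which indicates the backward kernel expects the forward output y = acosh(x) rather than the raw input: since x = cosh(y), the derivative is d/dx acosh(x) = 1/sinh(y) = 1/sqrt(x^2 - 1). The NumPy-only check below verifies that identity and shows why feeding the input where the output is expected gives a wrong gradient; it is an illustration, not MindSpore code.

```python
import numpy as np

x = np.linspace(1.1, 5.0, 8)
y = np.arccosh(x)                                 # the forward output, i.e. `out` in bprop

grad_from_input = 1.0 / np.sqrt(x * x - 1.0)      # textbook d/dx acosh(x)
grad_from_output = 1.0 / np.sinh(y)               # the same derivative written via the output
assert np.allclose(grad_from_input, grad_from_output)

# Feeding the *input* where the *output* is expected (the old code path) is simply wrong:
assert not np.allclose(grad_from_input, 1.0 / np.sinh(x))
```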
@@ -281,7 +281,7 @@ def get_bprop_elu(self):
     input_grad = G.EluGrad()
     def bprop(x, out, dout):
-        dx = input_grad(dout, x)
+        dx = input_grad(dout, out)
         return (dx,)
     return bprop
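The Elu fix is analogous: with the default alpha = 1.0, ELU(x) = x for x > 0 and e^x - 1 otherwise, so its derivative is 1 for x > 0 and e^x = ELU(x) + 1 otherwise. The gradient is therefore recoverable from the forward output alone, which is presumably why `G.EluGrad` takes `out` instead of `x` after this change. A NumPy-only verification of that identity (not MindSpore code):

```python
import numpy as np

x = np.linspace(-3.0, 3.0, 13)
out = np.where(x > 0, x, np.exp(x) - 1.0)             # forward ELU with alpha = 1.0

grad_from_input = np.where(x > 0, 1.0, np.exp(x))     # derivative written via the input
grad_from_output = np.where(out > 0, 1.0, out + 1.0)  # same derivative recovered from the output
assert np.allclose(grad_from_input, grad_from_output)
```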
@@ -308,7 +308,8 @@ class Elu(PrimitiveWithInfer):
     The data type of input tensor should be float.
     Args:
-        alpha (float): The coefficient of negative factor whose type is float. Default: 1.0.
+        alpha (float): The coefficient of negative factor whose type is float,
+            only support '1.0' currently. Default: 1.0.
     Inputs:
         - **input_x** (Tensor) - The input tensor whose data type should be float.
@@ -328,6 +329,7 @@ class Elu(PrimitiveWithInfer):
     def __init__(self, alpha=1.0):
         """Init Elu"""
         validator.check_value_type("alpha", alpha, [float], self.name)
+        validator.check_number("alpha", alpha, 1.0, Rel.EQ, self.name)
     def infer_shape(self, input_x):
         return input_x
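The new `validator.check_number` call pins `alpha` to exactly 1.0 at construction time, in line with the updated docstring and, presumably, with an `EluGrad` kernel that assumes the default coefficient. A hypothetical usage sketch follows; it assumes the conventional `from mindspore.ops import operations as P` alias used in the test file below, and the concrete exception type and message are decided by the framework's validator, not guaranteed here.

```python
# Hypothetical usage sketch; the exact exception raised for a rejected alpha is
# whatever validator.check_number produces, so it is caught broadly here.
from mindspore.ops import operations as P

elu_default = P.Elu()             # default alpha=1.0 passes both validator checks
elu_explicit = P.Elu(alpha=1.0)   # explicitly the only supported value

try:
    P.Elu(alpha=0.9)              # any other float is now rejected at construction time
except Exception as err:
    print(type(err).__name__, err)
```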
@@ -123,7 +123,7 @@ raise_set = [
         'skip': ['backward']}),
     # input is Tensor(int32)
     ('Elu1', {
-        'block': (P.Elu(alpha=0.9), {'exception': TypeError, 'error_keywords': ['Elu']}),
+        'block': (P.Elu(), {'exception': TypeError, 'error_keywords': ['Elu']}),
         'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32))],
         'skip': ['backward']}),
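The 'Elu1' entry exists to check that a Tensor(int32) input is rejected with a TypeError. It previously built the primitive as P.Elu(alpha=0.9); after this commit that constructor call would already fail the new alpha check, so the dtype error the test targets would never be reached. Dropping the argument keeps the test focused on the input dtype. The toy pytest sketch below illustrates that ordering with a stand-in class; it is not the real MindSpore test pipeline, and the helper names are invented for illustration.

```python
# Toy illustration (plain pytest, not the MindSpore test pipeline) of why the test
# had to drop alpha=0.9: construction-time validation runs before any input-dtype
# check, so an unsupported alpha would mask the TypeError the test actually targets.
import numpy as np
import pytest


class ToyElu:
    def __init__(self, alpha=1.0):
        if alpha != 1.0:                       # mirrors validator.check_number(..., Rel.EQ, ...)
            raise ValueError("only alpha=1.0 is supported")
        self.alpha = alpha

    def __call__(self, x):
        if x.dtype != np.float32:              # mirrors the dtype validation of the input tensor
            raise TypeError("input of Elu must be float")
        return np.where(x > 0, x, np.exp(x) - 1.0)


def test_int32_input_raises_type_error():
    op = ToyElu()                              # like P.Elu(): default alpha, so the dtype check is reached
    with pytest.raises(TypeError):
        op(np.ones((3, 4), dtype=np.int32))


def test_bad_alpha_fails_before_dtype_check():
    with pytest.raises(ValueError):            # like P.Elu(alpha=0.9): fails at construction
        ToyElu(alpha=0.9)
```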