提交 064d370a 编写于 作者: B buxue

fix bug of Acosh bprop and Elu bprop

上级 ef596f26
...@@ -128,7 +128,7 @@ class Validator: ...@@ -128,7 +128,7 @@ class Validator:
@staticmethod @staticmethod
def check_number(arg_name, arg_value, value, rel, prim_name): def check_number(arg_name, arg_value, value, rel, prim_name):
"""Integer value judgment.""" """Number value judgment."""
rel_fn = Rel.get_fns(rel) rel_fn = Rel.get_fns(rel)
if not rel_fn(arg_value, value): if not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(value) rel_str = Rel.get_strs(rel).format(value)
......
...@@ -727,7 +727,7 @@ def get_bprop_acosh(self): ...@@ -727,7 +727,7 @@ def get_bprop_acosh(self):
input_grad = G.AcoshGrad() input_grad = G.AcoshGrad()
def bprop(x, out, dout): def bprop(x, out, dout):
dx = input_grad(x, dout) dx = input_grad(out, dout)
return (dx,) return (dx,)
return bprop return bprop
......
...@@ -281,7 +281,7 @@ def get_bprop_elu(self): ...@@ -281,7 +281,7 @@ def get_bprop_elu(self):
input_grad = G.EluGrad() input_grad = G.EluGrad()
def bprop(x, out, dout): def bprop(x, out, dout):
dx = input_grad(dout, x) dx = input_grad(dout, out)
return (dx,) return (dx,)
return bprop return bprop
......
...@@ -308,7 +308,8 @@ class Elu(PrimitiveWithInfer): ...@@ -308,7 +308,8 @@ class Elu(PrimitiveWithInfer):
The data type of input tensor should be float. The data type of input tensor should be float.
Args: Args:
alpha (float): The coefficient of negative factor whose type is float. Default: 1.0. alpha (float): The coefficient of negative factor whose type is float,
only support '1.0' currently. Default: 1.0.
Inputs: Inputs:
- **input_x** (Tensor) - The input tensor whose data type should be float. - **input_x** (Tensor) - The input tensor whose data type should be float.
...@@ -328,6 +329,7 @@ class Elu(PrimitiveWithInfer): ...@@ -328,6 +329,7 @@ class Elu(PrimitiveWithInfer):
def __init__(self, alpha=1.0): def __init__(self, alpha=1.0):
"""Init Elu""" """Init Elu"""
validator.check_value_type("alpha", alpha, [float], self.name) validator.check_value_type("alpha", alpha, [float], self.name)
validator.check_number("alpha", alpha, 1.0, Rel.EQ, self.name)
def infer_shape(self, input_x): def infer_shape(self, input_x):
return input_x return input_x
......
...@@ -123,7 +123,7 @@ raise_set = [ ...@@ -123,7 +123,7 @@ raise_set = [
'skip': ['backward']}), 'skip': ['backward']}),
# input is Tensor(int32) # input is Tensor(int32)
('Elu1', { ('Elu1', {
'block': (P.Elu(alpha=0.9), {'exception': TypeError, 'error_keywords': ['Elu']}), 'block': (P.Elu(), {'exception': TypeError, 'error_keywords': ['Elu']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32))], 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32))],
'skip': ['backward']}), 'skip': ['backward']}),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册