diff --git a/tests/ut/python/pynative_mode/test_cell_bprop.py b/tests/ut/python/pynative_mode/test_cell_bprop.py index 054afe36c9d467ee08f2e429f7308f3308d7bc46..da1e14974fc10812a8e3117d5817fb2621a7e71b 100644 --- a/tests/ut/python/pynative_mode/test_cell_bprop.py +++ b/tests/ut/python/pynative_mode/test_cell_bprop.py @@ -64,14 +64,15 @@ def test_grad_inline_mul_add(): class WithParameter(nn.Cell): def __init__(self): super(WithParameter, self).__init__() - self.param = Parameter(2, 'param') + self.param1 = Parameter(1, 'param1') + self.param2 = Parameter(2, 'param2') def construct(self, x, y): - return self.param * x + y + return self.param1 * self.param2 * x + y def bprop(self, x, y, out, dout): # In this test case, The user defined bprop is wrong defined purposely to distinguish from ad result - return self.param * dout, 2 * y + return self.param1 * self.param2 * dout, 2 * y def test_with_param(): with_param = WithParameter()