diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index 4b18cce5a2c4d911087e43b95f39a210451f72f8..685b3ea19d4d79de8a928a38580e35abc6078f10 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -127,10 +127,9 @@ def get_bprop_squeeze(self): @bprop_getters.register(P.Flatten) def get_bprop_flatten(self): """Generate bprop for Flatten""" - flatten_grad = G.FlattenGrad() def bprop(x, out, dout): - dx = flatten_grad(dout, shape_op(x)) + dx = reshape(dout, shape_op(x)) return (dx,) return bprop diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 7b018093331149b4674d0facaeb1bfd1c40511ea..7c9568d5df08885c47e7963b28ae17db40080f0c 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -695,7 +695,7 @@ test_case_nn_ops = [ ('Flatten', { 'block': P.Flatten(), 'desc_inputs': [[128, 32, 32, 64]], - 'desc_bprop': [[128 * 32 * 8 * 16]]}), + 'desc_bprop': [[128, 65536]]}), ('LogSoftmax', { 'block': P.LogSoftmax(), 'desc_inputs': [[64, 2]],