diff --git a/python/paddle/fluid/tests/unittests/parallel_margin_cross_entropy.py b/python/paddle/fluid/tests/unittests/parallel_margin_cross_entropy.py index b77a04d8eea9c255269e2119bac3c8052b848b5f..26e9e05b82ab8164a8ffff2cd3f5aa432089828e 100644 --- a/python/paddle/fluid/tests/unittests/parallel_margin_cross_entropy.py +++ b/python/paddle/fluid/tests/unittests/parallel_margin_cross_entropy.py @@ -39,6 +39,7 @@ class TestParallelMarginSoftmaxCrossEntropyOp(unittest.TestCase): def setUp(self): strategy = fleet.DistributedStrategy() fleet.init(is_collective=True, strategy=strategy) + paddle.fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True}) def test_parallel_margin_softmax_cross_entropy(self): margin1s = [1.0, 1.0, 1.35]