diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py index 6a0d955040f353113598e63a52ab2989b79b8506..c4be262e93029cf47efcbc67482d4481838b39f0 100644 --- a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py +++ b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py @@ -1208,7 +1208,7 @@ class CrossEntropyLoss(unittest.TestCase): self.assertIsNotNone(static_ret) with fluid.dygraph.guard(): cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( - weight=fluid.dygraph.to_variable(weight_np), reduction='mean') + weight=fluid.dygraph.to_variable(weight_np), reduction='mean', axis=1) dy_ret = cross_entropy_loss( fluid.dygraph.to_variable(input_np), fluid.dygraph.to_variable(label_np))