From b4f49aa196d246e629b19707a972c8ee4a09184c Mon Sep 17 00:00:00 2001 From: wangxiaoning <71813629+wangxn12138@users.noreply.github.com> Date: Tue, 14 Mar 2023 14:59:43 +0800 Subject: [PATCH] fix rank=1 (#51413) --- paddle/phi/infermeta/ternary.cc | 2 +- python/paddle/fluid/tests/unittests/test_gather_nd_op.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index 9f787d07753..d2f1c78eb0b 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -1073,7 +1073,7 @@ void ScatterNdAddInferMeta(const MetaTensor& x, index_dims[index_dims_size - 1], ref_dims_size)); PADDLE_ENFORCE_GE(index_dims_size, - 2UL, + 1UL, phi::errors::InvalidArgument( "The rank of Input(Index) should be greater than 1, " "but received the rank of Input(Index) is %d.", diff --git a/python/paddle/fluid/tests/unittests/test_gather_nd_op.py b/python/paddle/fluid/tests/unittests/test_gather_nd_op.py index 4916a814a23..6fc85d1597d 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_nd_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_nd_op.py @@ -54,7 +54,7 @@ class TestGatherNdOpWithIndex1(OpTest): self.check_output(check_eager=False) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_eager=False) + self.check_grad(['X'], 'Out', check_eager=False, check_prim=True) class TestGatherNdOpWithLowIndex(OpTest): -- GitLab