diff --git a/paddle/fluid/operators/top_k_v2_op.cc b/paddle/fluid/operators/top_k_v2_op.cc index cc72d83411f5a34561a75e7e75f98077ee5a4e5d..0e3fcced19ea8eb1580ca93fa9d6616685601f75 100644 --- a/paddle/fluid/operators/top_k_v2_op.cc +++ b/paddle/fluid/operators/top_k_v2_op.cc @@ -32,7 +32,6 @@ class TopkV2Op : public framework::OperatorWithKernel { auto input_dims = ctx->GetInputDim("X"); const int& dim_size = input_dims.size(); - const int k = static_cast(ctx->Attrs().Get("k")); int axis = static_cast(ctx->Attrs().Get("axis")); PADDLE_ENFORCE_EQ((axis < dim_size) && (axis >= (-1 * dim_size)), true, "the axis of topk" @@ -41,8 +40,18 @@ class TopkV2Op : public framework::OperatorWithKernel { if (axis < 0) axis += dim_size; - PADDLE_ENFORCE_GE( - k, 1, "the attribute of k in the topk must >= 1, but received %d .", k); + int k; + auto k_is_tensor = ctx->HasInput("K"); + if (k_is_tensor) { + k = -1; + } else { + k = static_cast(ctx->Attrs().Get("k")); + PADDLE_ENFORCE_EQ(k >= 1, true, + "the attribute of k in the topk must >= 1 or be a " + "Tensor, but received %d .", + k); + } + PADDLE_ENFORCE_GE(input_dims.size(), 1, "input of topk must have >= 1d shape"); diff --git a/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py b/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py index 54e7765c0fb76844a6123fceea6c1ef79dc0c2bf..b9d96f329b5bb48f7167d005f11f64136fdf5d01 100644 --- a/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_top_k_v2_op.py @@ -63,28 +63,28 @@ class TestTopkOp(OpTest): self.check_grad(set(['X']), 'Out') -class TestTopOp1(TestTopkOp): +class TestTopkOp1(TestTopkOp): def init_args(self): self.k = 3 self.axis = 0 self.largest = True -class TestTopOp2(TestTopkOp): +class TestTopkOp2(TestTopkOp): def init_args(self): self.k = 3 self.axis = 0 self.largest = False -class TestTopOp3(TestTopkOp): +class TestTopkOp3(TestTopkOp): def init_args(self): self.k = 4 self.axis = 0 self.largest = False -class TestTopOp4(TestTopkOp): +class TestTopkOp4(TestTopkOp): def init_args(self): self.k = 4 self.axis = 0 @@ -189,6 +189,8 @@ class TestTopKAPI(unittest.TestCase): result1 = paddle.topk(input_tensor, k=2) result2 = paddle.topk(input_tensor, k=2, axis=-1) result3 = paddle.topk(input_tensor, k=k_tensor, axis=1) + self.assertEqual(result3[0].shape, (6, -1, 8)) + self.assertEqual(result3[1].shape, (6, -1, 8)) result4 = paddle.topk(input_tensor, k=2, axis=1, largest=False) result5 = paddle.topk(input_tensor, k=2, axis=-1, largest=False) result6 = paddle.topk(large_input_tensor, k=1, axis=-1) @@ -239,6 +241,15 @@ class TestTopKAPI(unittest.TestCase): self.run_dygraph(place) self.run_static(place) + def test_errors(self): + paddle.disable_static() + x = paddle.to_tensor([1, 2, 3]) + with self.assertRaises(BaseException): + paddle.topk(x, k=-1) + + with self.assertRaises(BaseException): + paddle.topk(x, k=0) + if __name__ == "__main__": unittest.main()