未验证 提交 669efb98 编写于 作者: L LutaoChu 提交者: GitHub

Fix bug: shapes of Topk outputs are wrong when the parameter k is Tensor

Fix bug: shapes of Topk outputs are wrong when the parameter k is Tensor 
上级 c7e5cf16
......@@ -32,7 +32,6 @@ class TopkV2Op : public framework::OperatorWithKernel {
auto input_dims = ctx->GetInputDim("X");
const int& dim_size = input_dims.size();
const int k = static_cast<int>(ctx->Attrs().Get<int>("k"));
int axis = static_cast<int>(ctx->Attrs().Get<int>("axis"));
PADDLE_ENFORCE_EQ((axis < dim_size) && (axis >= (-1 * dim_size)), true,
"the axis of topk"
......@@ -41,8 +40,18 @@ class TopkV2Op : public framework::OperatorWithKernel {
if (axis < 0) axis += dim_size;
PADDLE_ENFORCE_GE(
k, 1, "the attribute of k in the topk must >= 1, but received %d .", k);
int k;
auto k_is_tensor = ctx->HasInput("K");
if (k_is_tensor) {
k = -1;
} else {
k = static_cast<int>(ctx->Attrs().Get<int>("k"));
PADDLE_ENFORCE_EQ(k >= 1, true,
"the attribute of k in the topk must >= 1 or be a "
"Tensor, but received %d .",
k);
}
PADDLE_ENFORCE_GE(input_dims.size(), 1,
"input of topk must have >= 1d shape");
......
......@@ -63,28 +63,28 @@ class TestTopkOp(OpTest):
self.check_grad(set(['X']), 'Out')
class TestTopOp1(TestTopkOp):
class TestTopkOp1(TestTopkOp):
def init_args(self):
self.k = 3
self.axis = 0
self.largest = True
class TestTopOp2(TestTopkOp):
class TestTopkOp2(TestTopkOp):
def init_args(self):
self.k = 3
self.axis = 0
self.largest = False
class TestTopOp3(TestTopkOp):
class TestTopkOp3(TestTopkOp):
def init_args(self):
self.k = 4
self.axis = 0
self.largest = False
class TestTopOp4(TestTopkOp):
class TestTopkOp4(TestTopkOp):
def init_args(self):
self.k = 4
self.axis = 0
......@@ -189,6 +189,8 @@ class TestTopKAPI(unittest.TestCase):
result1 = paddle.topk(input_tensor, k=2)
result2 = paddle.topk(input_tensor, k=2, axis=-1)
result3 = paddle.topk(input_tensor, k=k_tensor, axis=1)
self.assertEqual(result3[0].shape, (6, -1, 8))
self.assertEqual(result3[1].shape, (6, -1, 8))
result4 = paddle.topk(input_tensor, k=2, axis=1, largest=False)
result5 = paddle.topk(input_tensor, k=2, axis=-1, largest=False)
result6 = paddle.topk(large_input_tensor, k=1, axis=-1)
......@@ -239,6 +241,15 @@ class TestTopKAPI(unittest.TestCase):
self.run_dygraph(place)
self.run_static(place)
def test_errors(self):
paddle.disable_static()
x = paddle.to_tensor([1, 2, 3])
with self.assertRaises(BaseException):
paddle.topk(x, k=-1)
with self.assertRaises(BaseException):
paddle.topk(x, k=0)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册