未验证 提交 669efb98 编写于 作者: L LutaoChu 提交者: GitHub

Fix bug: shapes of Topk outputs are wrong when the parameter k is Tensor

Fix bug: shapes of Topk outputs are wrong when the parameter k is Tensor 
上级 c7e5cf16
...@@ -32,7 +32,6 @@ class TopkV2Op : public framework::OperatorWithKernel { ...@@ -32,7 +32,6 @@ class TopkV2Op : public framework::OperatorWithKernel {
auto input_dims = ctx->GetInputDim("X"); auto input_dims = ctx->GetInputDim("X");
const int& dim_size = input_dims.size(); const int& dim_size = input_dims.size();
const int k = static_cast<int>(ctx->Attrs().Get<int>("k"));
int axis = static_cast<int>(ctx->Attrs().Get<int>("axis")); int axis = static_cast<int>(ctx->Attrs().Get<int>("axis"));
PADDLE_ENFORCE_EQ((axis < dim_size) && (axis >= (-1 * dim_size)), true, PADDLE_ENFORCE_EQ((axis < dim_size) && (axis >= (-1 * dim_size)), true,
"the axis of topk" "the axis of topk"
...@@ -41,8 +40,18 @@ class TopkV2Op : public framework::OperatorWithKernel { ...@@ -41,8 +40,18 @@ class TopkV2Op : public framework::OperatorWithKernel {
if (axis < 0) axis += dim_size; if (axis < 0) axis += dim_size;
PADDLE_ENFORCE_GE( int k;
k, 1, "the attribute of k in the topk must >= 1, but received %d .", k); auto k_is_tensor = ctx->HasInput("K");
if (k_is_tensor) {
k = -1;
} else {
k = static_cast<int>(ctx->Attrs().Get<int>("k"));
PADDLE_ENFORCE_EQ(k >= 1, true,
"the attribute of k in the topk must >= 1 or be a "
"Tensor, but received %d .",
k);
}
PADDLE_ENFORCE_GE(input_dims.size(), 1, PADDLE_ENFORCE_GE(input_dims.size(), 1,
"input of topk must have >= 1d shape"); "input of topk must have >= 1d shape");
......
...@@ -63,28 +63,28 @@ class TestTopkOp(OpTest): ...@@ -63,28 +63,28 @@ class TestTopkOp(OpTest):
self.check_grad(set(['X']), 'Out') self.check_grad(set(['X']), 'Out')
class TestTopOp1(TestTopkOp): class TestTopkOp1(TestTopkOp):
def init_args(self): def init_args(self):
self.k = 3 self.k = 3
self.axis = 0 self.axis = 0
self.largest = True self.largest = True
class TestTopOp2(TestTopkOp): class TestTopkOp2(TestTopkOp):
def init_args(self): def init_args(self):
self.k = 3 self.k = 3
self.axis = 0 self.axis = 0
self.largest = False self.largest = False
class TestTopOp3(TestTopkOp): class TestTopkOp3(TestTopkOp):
def init_args(self): def init_args(self):
self.k = 4 self.k = 4
self.axis = 0 self.axis = 0
self.largest = False self.largest = False
class TestTopOp4(TestTopkOp): class TestTopkOp4(TestTopkOp):
def init_args(self): def init_args(self):
self.k = 4 self.k = 4
self.axis = 0 self.axis = 0
...@@ -189,6 +189,8 @@ class TestTopKAPI(unittest.TestCase): ...@@ -189,6 +189,8 @@ class TestTopKAPI(unittest.TestCase):
result1 = paddle.topk(input_tensor, k=2) result1 = paddle.topk(input_tensor, k=2)
result2 = paddle.topk(input_tensor, k=2, axis=-1) result2 = paddle.topk(input_tensor, k=2, axis=-1)
result3 = paddle.topk(input_tensor, k=k_tensor, axis=1) result3 = paddle.topk(input_tensor, k=k_tensor, axis=1)
self.assertEqual(result3[0].shape, (6, -1, 8))
self.assertEqual(result3[1].shape, (6, -1, 8))
result4 = paddle.topk(input_tensor, k=2, axis=1, largest=False) result4 = paddle.topk(input_tensor, k=2, axis=1, largest=False)
result5 = paddle.topk(input_tensor, k=2, axis=-1, largest=False) result5 = paddle.topk(input_tensor, k=2, axis=-1, largest=False)
result6 = paddle.topk(large_input_tensor, k=1, axis=-1) result6 = paddle.topk(large_input_tensor, k=1, axis=-1)
...@@ -239,6 +241,15 @@ class TestTopKAPI(unittest.TestCase): ...@@ -239,6 +241,15 @@ class TestTopKAPI(unittest.TestCase):
self.run_dygraph(place) self.run_dygraph(place)
self.run_static(place) self.run_static(place)
def test_errors(self):
paddle.disable_static()
x = paddle.to_tensor([1, 2, 3])
with self.assertRaises(BaseException):
paddle.topk(x, k=-1)
with self.assertRaises(BaseException):
paddle.topk(x, k=0)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册