提交 0fca5781 编写于 作者: H Hongyu Liu 提交者: phlrain

Merge pull request #16351 from phlrain/fix_topk_shape_check

Fix topk shape check
上级 e61d7245
...@@ -34,8 +34,11 @@ class TopkOp : public framework::OperatorWithKernel { ...@@ -34,8 +34,11 @@ class TopkOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_GE(k, 1, "k must >= 1"); PADDLE_ENFORCE_GE(k, 1, "k must >= 1");
PADDLE_ENFORCE_GE(input_dims.size(), 1, "input must have >= 1d shape"); PADDLE_ENFORCE_GE(input_dims.size(), 1, "input must have >= 1d shape");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_GE(input_dims[input_dims.size() - 1], k, PADDLE_ENFORCE_GE(input_dims[input_dims.size() - 1], k,
"input must have >= k columns"); "input must have >= k columns");
}
framework::DDim dims = input_dims; framework::DDim dims = input_dims;
dims[dims.size() - 1] = k; dims[dims.size() - 1] = k;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册