未验证 提交 ee9832a7 编写于 作者: Q qingqing01 提交者: GitHub

Add Top-k Python API. (#9973)

* Add topk Python API.

* Add unit test.

* Remove the repeated API.
上级 e5b3eb98
......@@ -815,3 +815,8 @@ zeros
.. autofunction:: paddle.fluid.layers.zeros
:noindex:
topk
----
.. autofunction:: paddle.fluid.layers.topk
:noindex:
......@@ -24,7 +24,6 @@ namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
......@@ -36,9 +35,9 @@ class TopkKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
// Get the top k elements of each row of input tensor
// FIXME: only deal with matrix(2d tensor).
auto* input = ctx.Input<LoDTensor>("X");
auto* output = ctx.Output<LoDTensor>("Out");
auto* indices = ctx.Output<LoDTensor>("Indices");
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
auto* indices = ctx.Output<Tensor>("Indices");
// k is determined by Attr
const size_t k = static_cast<int>(ctx.Attr<int>("k"));
......
......@@ -32,7 +32,6 @@ __all__ = [
'Switch',
'lod_rank_table',
'max_sequence_len',
'topk',
'lod_tensor_to_array',
'array_to_lod_tensor',
'increment',
......@@ -751,43 +750,6 @@ def max_sequence_len(rank_table):
return res
def topk(input, k):
"""
**topk**
This function performs the operation that selects the k entries in the input
vector and outputs their values and indices as vectors. Thus topk_out[j] is
the j-th largest entry in input, and its index is topk_indices[j]
Args:
input (Variable|list): The input tensor that has all the data.
k (int): The number of top elements that the function will pick.
Returns:
Variable: The variable of type array that contains the k largest entries
from input.
Variable: The variable of type array that contains the indices of k
largest entries from input.
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[10])
k = 5
array = fluid.layers.topk(x, k)
"""
helper = LayerHelper('topk', **locals())
topk_out = helper.create_tmp_variable(dtype=input.dtype)
topk_indices = helper.create_tmp_variable(dtype='int64')
helper.append_op(
type='top_k',
inputs={'X': [input]},
outputs={'Out': [topk_out],
'Indices': [topk_indices]},
attrs={'k': k})
return topk_out, topk_indices
def lod_tensor_to_array(x, table):
""" Convert a LOD_TENSOR to an LOD_TENSOR_ARRAY.
......
......@@ -20,6 +20,7 @@ from ..layer_helper import LayerHelper
from ..initializer import Normal, Constant
from ..framework import Variable
from ..param_attr import ParamAttr
import nn
__all__ = ['accuracy', 'auc']
......@@ -27,17 +28,10 @@ __all__ = ['accuracy', 'auc']
def accuracy(input, label, k=1, correct=None, total=None):
"""
This function computes the accuracy using the input and label.
The output is the top_k inputs and their indices.
The output is the top k inputs and their indices.
"""
helper = LayerHelper("accuracy", **locals())
topk_out = helper.create_tmp_variable(dtype=input.dtype)
topk_indices = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="top_k",
inputs={"X": [input]},
outputs={"Out": [topk_out],
"Indices": [topk_indices]},
attrs={"k": k})
topk_out, topk_indices = nn.topk(input, k=k)
acc_out = helper.create_tmp_variable(dtype="float32")
if correct is None:
correct = helper.create_tmp_variable(dtype="int64")
......@@ -68,12 +62,7 @@ def auc(input, label, curve='ROC', num_thresholds=200):
helper = LayerHelper("auc", **locals())
topk_out = helper.create_tmp_variable(dtype=input.dtype)
topk_indices = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="top_k",
inputs={"X": [input]},
outputs={"Out": [topk_out],
"Indices": [topk_indices]},
attrs={"k": k})
topk_out, topk_indices = nn.topk(input, k=k)
auc_out = helper.create_tmp_variable(dtype="float32")
if correct is None:
correct = helper.create_tmp_variable(dtype="int64")
......
......@@ -60,6 +60,7 @@ __all__ = [
'edit_distance',
'l2_normalize',
'matmul',
'topk',
'warpctc',
'sequence_reshape',
'transpose',
......@@ -2576,6 +2577,53 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
return out
def topk(input, k):
"""
This operator is used to find values and indices of the k largest entries
for the last dimension.
If the input is a vector (rank=1), finds the k largest entries in the vector
and outputs their values and indices as vectors. Thus values[j] is the j-th
largest entry in input, and its index is indices[j].
If the input is a Tensor with higher rank, this operator computes the top k
entries along the last dimension.
Args:
input(Variable): The input variable which can be a vector or Tensor with
higher rank.
k(int): An integer value to specify the top k largest elements.
Returns:
values(Variable): The k largest elements along each last dimensional
slice.
indices(Variable): The indices of values within the last dimension of
input.
Examples:
.. code-block:: python
top5_values, top5_indices = layers.topk(input, k=5)
"""
shape = input.shape
if k < 1 and k >= shape[-1]:
raise ValueError("k must be greater than 0 and less than %d." %
(shape[-1]))
helper = LayerHelper("top_k", **locals())
values = helper.create_tmp_variable(dtype=input.dtype)
indices = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="top_k",
inputs={"X": [input]},
outputs={"Out": [values],
"Indices": [indices]},
attrs={"k": k})
values.stop_gradient = True
indices.stop_gradient = True
return values, indices
def edit_distance(input, label, normalized=True, ignored_tokens=None,
name=None):
"""
......@@ -2717,15 +2765,7 @@ def ctc_greedy_decoder(input, blank, name=None):
cost = fluid.layers.ctc_greedy_decoder(input=x, blank=0)
"""
helper = LayerHelper("ctc_greedy_decoder", **locals())
# top 1 op
topk_out = helper.create_tmp_variable(dtype=input.dtype)
topk_indices = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="top_k",
inputs={"X": [input]},
outputs={"Out": [topk_out],
"Indices": [topk_indices]},
attrs={"k": 1})
_, topk_indices = topk(input, k=1)
# ctc align op
ctc_out = helper.create_tmp_variable(dtype="int64")
......
......@@ -350,6 +350,15 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(smooth_label)
print(str(program))
def test_topk(self):
program = Program()
with program_guard(program):
data = layers.data(name="label", shape=[200], dtype="float32")
values, indices = layers.topk(data, k=5)
self.assertIsNotNone(values)
self.assertIsNotNone(indices)
print(str(program))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册