search.py 27.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
C
Chengmo 已提交
14
from __future__ import print_function
15
import numpy as np
C
Chengmo 已提交
16 17
from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype
18
from ..fluid import core, layers
19

20
# TODO: define searching & indexing functions of a tensor  
21 22
# from ..fluid.layers import has_inf  #DEFINE_ALIAS
# from ..fluid.layers import has_nan  #DEFINE_ALIAS
23

24 25
# Public API exported by ``from paddle.tensor.search import *``.
__all__ = [
    'argmax', 'argmin', 'argsort', 'masked_select', 'topk', 'where',
    'index_select', 'nonzero', 'sort', 'index_sample'
]

from paddle.common_ops_import import *
38 39


40 41
def argsort(x, axis=-1, descending=False, name=None):
    """
    This OP sorts the input along the given axis and returns the indices of
    the sorted values. The sort is ascending by default; set
    :attr:`descending` to True to sort in descending order.

    Args:
        x(Tensor): An input N-D Tensor with type float32, float64, int16,
            int32, int64, uint8.
        axis(int, optional): Axis to compute indices along. The effective range
            is [-R, R), where R is Rank(x). When axis < 0, it works the same
            way as axis + R. Default is -1 (the last axis).
        descending(bool, optional) : Descending is a flag, if set to true,
            algorithm will sort by descending order, else sort by
            ascending order. Default is false.
        name(str, optional): The default value is None. Normally there is no
            need for user to set this property. For more information, please
            refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: sorted indices (with the same shape as ``x``
        and with data type int64).

    Examples:

        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[[5,8,9,5],
                                   [0,0,1,7],
                                   [6,9,2,4]],
                                  [[5,2,4,2],
                                   [4,7,7,9],
                                   [1,7,0,6]]],
                                 dtype='float32')
            out1 = paddle.argsort(x=x, axis=-1)
            out2 = paddle.argsort(x=x, axis=0)
            out3 = paddle.argsort(x=x, axis=1)
            print(out1)
            #[[[0 3 1 2]
            #  [0 1 2 3]
            #  [2 3 0 1]]
            # [[1 3 2 0]
            #  [0 1 2 3]
            #  [2 0 3 1]]]
            print(out2)
            #[[[0 1 1 1]
            #  [0 0 0 0]
            #  [1 1 1 0]]
            # [[1 0 0 0]
            #  [1 1 1 1]
            #  [0 0 0 1]]]
            print(out3)
            #[[[1 1 1 2]
            #  [0 0 2 0]
            #  [2 2 0 1]]
            # [[2 0 2 0]
            #  [1 1 0 2]
            #  [0 2 1 1]]]
    """
    if in_dygraph_mode():
        # The argsort op yields (sorted values, indices); this API only
        # exposes the indices.
        _, ids = core.ops.argsort(x, 'axis', axis, 'descending', descending)
        return ids
    check_variable_and_dtype(
        x, 'x', ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'],
        'argsort')

    helper = LayerHelper("argsort", **locals())
    # The op writes the sorted values to 'Out' and the indices to 'Indices';
    # only the indices are returned to the caller.
    out = helper.create_variable_for_type_inference(
        dtype=x.dtype, stop_gradient=True)
    ids = helper.create_variable_for_type_inference(
        VarDesc.VarType.INT64, stop_gradient=True)
    helper.append_op(
        type='argsort',
        inputs={'X': x},
        outputs={'Out': out,
                 'Indices': ids},
        attrs={'axis': axis,
               'descending': descending})
    return ids


121
def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
    """
    This OP computes the indices of the max elements of the input tensor's
    element along the provided axis.

    Args:
        x(Tensor): An input N-D Tensor with type float32, float64, int16,
            int32, int64, uint8.
        axis(int, optional): Axis to compute indices along. The effective range
            is [-R, R), where R is x.ndim. When axis < 0, it works the same way
            as axis + R. Default is None, in which case the input `x` is
            flattened and the index of the max value of the flattened tensor
            is returned.
        keepdim(bool, optional): Whether to keep the axis along which the max
            is selected. The default value is False.
        dtype(str|np.dtype, optional): Data type of the output tensor which can
                    be int32, int64. The default value is 'int64', and it will
                    return the int64 indices.
        name(str, optional): The default value is None. Normally there is no
            need for user to set this property. For more information, please
            refer to :ref:`api_guide_Name`.

    Returns:
        Tensor, return the tensor of `int32` if set :attr:`dtype` is `int32`, otherwise return the tensor of `int64`

    Raises:
        TypeError: If ``axis`` is neither an int nor None.
        ValueError: If ``dtype`` is None.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[5,8,9,5],
                                  [0,0,1,7],
                                  [6,9,2,4]])
            out1 = paddle.argmax(x)
            print(out1) # 2
            out2 = paddle.argmax(x, axis=1)
            print(out2)
            # [2 3 1]
            out3 = paddle.argmax(x, axis=-1)
            print(out3)
            # [2 3 1]
    """
    if axis is not None and not isinstance(axis, int):
        raise TypeError(
            "The type of 'axis'  must be int or None in argmax, but received %s."
            % (type(axis)))

    if dtype is None:
        raise ValueError(
            "the value of 'dtype' in argmax could not be None, but received None"
        )

    var_dtype = convert_np_dtype_to_dtype_(dtype)
    # Report dtype errors under this op's own name so they point at argmax.
    check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmax')

    # axis=None selects over the flattened tensor: the op receives
    # flatten=True together with a placeholder axis of 0.
    flatten = False
    if axis is None:
        flatten = True
        axis = 0

    if in_dygraph_mode():
        out = core.ops.arg_max(x, 'axis', axis, 'dtype', var_dtype, 'keepdims',
                               keepdim, 'flatten', flatten)
        return out

    helper = LayerHelper("argmax", **locals())
    check_variable_and_dtype(
        x, 'x', ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'],
        'paddle.argmax')
    out = helper.create_variable_for_type_inference(var_dtype)
    attrs = {
        'keepdims': keepdim,
        'axis': axis,
        'flatten': flatten,
        'dtype': var_dtype,
    }
    helper.append_op(
        type='arg_max', inputs={'X': x}, outputs={'Out': [out]}, attrs=attrs)
    # Indices are not differentiable.
    out.stop_gradient = True
    return out


198
def argmin(x, axis=None, keepdim=False, dtype="int64", name=None):
    """
    This OP returns the indices of the minimum elements of ``x`` along the
    given axis.

    Args:
        x(Tensor): An input N-D Tensor with data type float32, float64,
            int16, int32, int64 or uint8.
        axis(int, optional): Axis along which the minimum is searched. The
            effective range is [-R, R), where R is x.ndim; a negative axis is
            treated as axis + R. Default is None, in which case ``x`` is
            flattened and the index of the minimum of the flattened tensor is
            returned.
        keepdim(bool, optional): Whether to keep the axis along which the
            minimum is selected. The default value is False.
        dtype(str|np.dtype, optional): Data type of the output indices, either
            int32 or int64. The default value is 'int64'.
        name(str, optional): The default value is None. Normally there is no
            need for user to set this property. For more information, please
            refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: indices of type ``int32`` when :attr:`dtype` is ``int32``,
        otherwise of type ``int64``.

    Raises:
        TypeError: If ``axis`` is neither an int nor None.
        ValueError: If ``dtype`` is None.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[5,8,9,5],
                                  [0,0,1,7],
                                  [6,9,2,4]])
            out1 = paddle.argmin(x)
            print(out1) # 4
            out2 = paddle.argmin(x, axis=1)
            print(out2)
            # [0 0 2]
            out3 = paddle.argmin(x, axis=-1)
            print(out3)
            # [0 0 2]
    """
    # Argument validation runs in both dygraph and static graph mode.
    if not (axis is None or isinstance(axis, int)):
        raise TypeError(
            "The type of 'axis'  must be int or None in argmin, but received %s."
            % (type(axis)))

    if dtype is None:
        raise ValueError(
            "the value of 'dtype' in argmin could not be None, but received None"
        )

    var_dtype = convert_np_dtype_to_dtype_(dtype)
    check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin')

    # No axis means "search the flattened tensor"; the op is told so via
    # flatten=True and given a placeholder axis of 0.
    flatten = axis is None
    if flatten:
        axis = 0

    if in_dygraph_mode():
        return core.ops.arg_min(x, 'axis', axis, 'dtype', var_dtype,
                                'keepdims', keepdim, 'flatten', flatten)

    helper = LayerHelper("argmin", **locals())
    check_variable_and_dtype(
        x, 'x', ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'],
        'paddle.argmin')
    out = helper.create_variable_for_type_inference(var_dtype)
    helper.append_op(
        type='arg_min',
        inputs={'X': x},
        outputs={'Out': [out]},
        attrs={
            'keepdims': keepdim,
            'axis': axis,
            'flatten': flatten,
            'dtype': var_dtype,
        })
    # Index outputs carry no gradient.
    out.stop_gradient = True
    return out
273 274


275
def index_select(x, index, axis=0, name=None):
    """

    Returns a new tensor which indexes the ``x`` tensor along dimension ``axis`` using
    the entries in ``index`` which is a Tensor. The returned tensor has the same number
    of dimensions as the original ``x`` tensor. The ``axis``-th dimension has the same
    size as the length of ``index``; other dimensions have the same size as in the ``x`` tensor.

    Args:
        x (Tensor): The input Tensor to be operated. The data of ``x`` can be one of float32, float64, int32, int64.
        index (Tensor): The 1-D Tensor containing the indices to index. The data type of ``index`` must be int32 or int64.
        axis (int, optional): The dimension in which we index. Default: 0. Passing None is treated as 0.
        name(str, optional): The default value is None. Normally there is no
            need for user to set this property. For more information, please
            refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: A Tensor with same data type as ``x``.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0],
                                  [5.0, 6.0, 7.0, 8.0],
                                  [9.0, 10.0, 11.0, 12.0]])
            index = paddle.to_tensor([0, 1, 1], dtype='int32')
            out_z1 = paddle.index_select(x=x, index=index)
            #[[1. 2. 3. 4.]
            # [5. 6. 7. 8.]
            # [5. 6. 7. 8.]]
            out_z2 = paddle.index_select(x=x, index=index, axis=1)
            #[[ 1.  2.  2.]
            # [ 5.  6.  6.]
            # [ 9. 10. 10.]]
    """
    # Honor the documented "None means axis 0" contract instead of forwarding
    # None as the op's 'dim' attribute.
    if axis is None:
        axis = 0

    if in_dygraph_mode():
        return core.ops.index_select(x, index, 'dim', axis)

    helper = LayerHelper("index_select", **locals())
    check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'],
                             'paddle.tensor.search.index_select')
    check_variable_and_dtype(index, 'index', ['int32', 'int64'],
                             'paddle.tensor.search.index_select')

    out = helper.create_variable_for_type_inference(x.dtype)

    helper.append_op(
        type='index_select',
        inputs={'X': x,
                'Index': index},
        outputs={'Out': out},
        attrs={'dim': axis})
    return out


333
def nonzero(x, as_tuple=False):
    """
    Return the indices of all non-zero elements of ``x``.

    Given an n-dimensional input ``x`` with shape [x_1, x_2, ..., x_n]:

    - If ``as_tuple`` is False, the result is a single Tensor of shape
      [z, n], where ``z`` is the number of non-zero elements of ``x``; each
      row holds the coordinates of one non-zero element.
    - If ``as_tuple`` is True, the result is a tuple of ``n`` Tensors, one per
      dimension of ``x``; the d-th Tensor has shape [z, 1] and holds the
      indices along dimension d of all non-zero elements.

    Args:
        x (Tensor): The input tensor variable.
        as_tuple (bool, optional): Whether to return a tuple of per-dimension
            index Tensors instead of a single Tensor. Default: False.

    Returns:
        Tensor|tuple(Tensor): The index Tensor, or a tuple of index Tensors
        when ``as_tuple`` is True. The data type is int64.

    Examples:

        .. code-block:: python

            import paddle

            x1 = paddle.to_tensor([[1.0, 0.0, 0.0],
                                   [0.0, 2.0, 0.0],
                                   [0.0, 0.0, 3.0]])
            x2 = paddle.to_tensor([0.0, 1.0, 0.0, 3.0])
            out_z1 = paddle.nonzero(x1)
            print(out_z1)
            #[[0 0]
            # [1 1]
            # [2 2]]
            out_z1_tuple = paddle.nonzero(x1, as_tuple=True)
            for out in out_z1_tuple:
                print(out)
            #[[0]
            # [1]
            # [2]]
            #[[0]
            # [1]
            # [2]]
            out_z2 = paddle.nonzero(x2)
            print(out_z2)
            #[[1]
            # [3]]
            out_z2_tuple = paddle.nonzero(x2, as_tuple=True)
            for out in out_z2_tuple:
                print(out)
            #[[1]
            # [3]]

    """
    rank = len(x.shape)

    if in_dygraph_mode():
        outs = core.ops.where_index(x)
    else:
        outs = layers.where(x)

    if not as_tuple:
        return outs
    if rank == 1:
        # A 1-D input already yields a single [z, 1] index column.
        return (outs,)
    # ``outs`` is a 2-D [z, rank] matrix whatever the rank of ``x`` is, so the
    # per-dimension columns always live on axis 1. Slicing on axis
    # ``rank - 1`` only coincides with this for 2-D inputs and is out of
    # range for inputs of rank 3 or more.
    return tuple(
        layers.slice(
            outs, axes=[1], starts=[i], ends=[i + 1]) for i in range(rank))


406
def sort(x, axis=-1, descending=False, name=None):
    """

    This OP sorts the input along the given axis and returns the sorted
    tensor. The sort is ascending by default; set :attr:`descending` to True
    to sort in descending order.

    Args:
        x(Tensor): An input N-D Tensor with type float32, float64, int16,
            int32, int64, uint8.
        axis(int, optional): Axis along which to sort. The effective range
            is [-R, R), where R is Rank(x). When axis < 0, it works the same
            way as axis + R. Default is -1 (the last axis).
        descending(bool, optional) : Descending is a flag, if set to true,
            algorithm will sort by descending order, else sort by
            ascending order. Default is false.
        name(str, optional): The default value is None. Normally there is no
            need for user to set this property. For more information, please
            refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: sorted tensor (with the same shape and data type as ``x``).

    Examples:

        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[[5,8,9,5],
                                   [0,0,1,7],
                                   [6,9,2,4]],
                                  [[5,2,4,2],
                                   [4,7,7,9],
                                   [1,7,0,6]]],
                                 dtype='float32')
            out1 = paddle.sort(x=x, axis=-1)
            out2 = paddle.sort(x=x, axis=0)
            out3 = paddle.sort(x=x, axis=1)
            print(out1)
            #[[[5. 5. 8. 9.]
            #  [0. 0. 1. 7.]
            #  [2. 4. 6. 9.]]
            # [[2. 2. 4. 5.]
            #  [4. 7. 7. 9.]
            #  [0. 1. 6. 7.]]]
            print(out2)
            #[[[5. 2. 4. 2.]
            #  [0. 0. 1. 7.]
            #  [1. 7. 0. 4.]]
            # [[5. 8. 9. 5.]
            #  [4. 7. 7. 9.]
            #  [6. 9. 2. 6.]]]
            print(out3)
            #[[[0. 0. 1. 4.]
            #  [5. 8. 2. 5.]
            #  [6. 9. 9. 7.]]
            # [[1. 2. 0. 2.]
            #  [4. 7. 4. 6.]
            #  [5. 7. 7. 9.]]]
    """
    if in_dygraph_mode():
        # Reuses the argsort op, which yields (sorted values, indices); only
        # the values are returned here.
        out, _ = core.ops.argsort(x, 'axis', axis, 'descending', descending)
        return out
    helper = LayerHelper("sort", **locals())
    # Sorted values keep gradients; the indices slot is filled but discarded.
    out = helper.create_variable_for_type_inference(
        dtype=x.dtype, stop_gradient=False)
    ids = helper.create_variable_for_type_inference(
        VarDesc.VarType.INT64, stop_gradient=True)
    helper.append_op(
        type='argsort',
        inputs={'X': x},
        outputs={'Out': out,
                 'Indices': ids},
        attrs={'axis': axis,
               'descending': descending})
    return out
C
Chengmo 已提交
479 480


481
def where(condition, x, y, name=None):
    r"""
    Return a tensor of elements selected from either $x$ or $y$, depending on $condition$.

    .. math::

      out_i =
      \begin{cases}
      x_i, \quad  \text{if}  \ condition_i \  is \ True \\
      y_i, \quad  \text{if}  \ condition_i \  is \ False \\
      \end{cases}


    Args:
        condition(Tensor): The condition to choose x or y. In static graph
            mode it must be a bool Tensor.
        x(Tensor): x is a Tensor with data type float32, float64, int32, int64.
        y(Tensor): y is a Tensor with data type float32, float64, int32, int64.

        name(str, optional): The default value is None. Normally there is no
            need for user to set this property. For more information, please
            refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: A Tensor with the same data type as x.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.to_tensor([0.9383, 0.1983, 3.2, 1.2])
            y = paddle.to_tensor([1.0, 1.0, 1.0, 1.0])
            out = paddle.where(x>1, x, y)

            print(out)
            #out: [1.0, 1.0, 3.2, 1.2]
    """
    if not in_dygraph_mode():
        check_variable_and_dtype(condition, 'condition', ['bool'], 'where')
        check_variable_and_dtype(
            x, 'x', ['float32', 'float64', 'int32', 'int64'], 'where')
        check_variable_and_dtype(
            y, 'y', ['float32', 'float64', 'int32', 'int64'], 'where')

    x_shape = list(x.shape)
    y_shape = list(y.shape)
    if x_shape == y_shape:
        # Same-shaped operands go straight to the dedicated where op.
        if in_dygraph_mode():
            return core.ops.where(condition, x, y)
        else:
            helper = LayerHelper("where", **locals())
            out = helper.create_variable_for_type_inference(dtype=x.dtype)

            helper.append_op(
                type='where',
                inputs={'Condition': condition,
                        'X': x,
                        'Y': y},
                outputs={'Out': [out]})
            return out
    else:
        # Differently-shaped operands: emulate selection arithmetically as
        # x * cond + y * (not cond), relying on the elementwise ops to
        # broadcast.
        # NOTE(review): with this formulation an inf/nan in the *unselected*
        # operand still leaks into the result (e.g. inf * 0 = nan).
        cond_int = layers.cast(condition, x.dtype)
        cond_not_int = layers.cast(layers.logical_not(condition), x.dtype)
        out1 = layers.elementwise_mul(x, cond_int)
        out2 = layers.elementwise_mul(y, cond_not_int)
        out = layers.elementwise_add(out1, out2)
        return out


C
Chengmo 已提交
550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573
def index_sample(x, index):
    """
    **IndexSample Layer**

    Gather, row by row, the elements of ``x`` at the column positions listed
    in ``index``.

    .. code-block:: text

                Given:

                X = [[1, 2, 3, 4, 5],
                     [6, 7, 8, 9, 10]]

                Index = [[0, 1, 3],
                         [0, 2, 4]]

                Then:

                Out = [[1, 2, 4],
                       [6, 8, 10]]

    Args:
        x (Tensor): The 2-D source tensor. Supported data types are
            int32, int64, float32 and float64.
        index (Tensor): The 2-D index tensor whose first dimension must equal
            that of ``x``. Data type is int32 or int64.

    Returns:
        Tensor: A tensor with the same shape as ``index`` and the same data
        type as ``x``.

    Examples:

        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0],
                                  [5.0, 6.0, 7.0, 8.0],
                                  [9.0, 10.0, 11.0, 12.0]], dtype='float32')
            index = paddle.to_tensor([[0, 1, 2],
                                      [1, 2, 3],
                                      [0, 0, 0]], dtype='int32')
            target = paddle.to_tensor([[100, 200, 300, 400],
                                       [500, 600, 700, 800],
                                       [900, 1000, 1100, 1200]], dtype='int32')
            out_z1 = paddle.index_sample(x, index)
            print(out_z1)
            #[[1. 2. 3.]
            # [6. 7. 8.]
            # [9. 9. 9.]]

            # Use the index of the maximum value by topk op
            # get the value of the element of the corresponding index in other tensors
            top_value, top_index = paddle.topk(x, k=2)
            out_z2 = paddle.index_sample(target, top_index)
            print(top_value)
            #[[ 4.  3.]
            # [ 8.  7.]
            # [12. 11.]]

            print(top_index)
            #[[3 2]
            # [3 2]
            # [3 2]]

            print(out_z2)
            #[[ 400  300]
            # [ 800  700]
            # [1200 1100]]

    """
    if in_dygraph_mode():
        return core.ops.index_sample(x, index)

    helper = LayerHelper("index_sample", **locals())
    op_path = 'paddle.tensor.search.index_sample'
    check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'],
                             op_path)
    check_variable_and_dtype(index, 'index', ['int32', 'int64'], op_path)
    result = helper.create_variable_for_type_inference(dtype=x.dtype)
    helper.append_op(
        type='index_sample',
        inputs={'X': x,
                'Index': index},
        outputs={'Out': result})
    return result
639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659


def masked_select(x, mask, name=None):
    """
    This OP returns a new 1-D tensor which indexes the input tensor according to the ``mask``
    which is a tensor with data type of bool.

    Args:
        x (Tensor): The input Tensor, the data type can be int32, int64, float32, float64.
        mask (Tensor): The Tensor containing the binary mask to index with, its data type is bool.
        name(str, optional): The default value is None. Normally there is no
            need for user to set this property. For more information, please
            refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: A 1-D Tensor with the same data type as ``x``.

    Examples:

        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0],
                                  [5.0, 6.0, 7.0, 8.0],
                                  [9.0, 10.0, 11.0, 12.0]])
            mask = paddle.to_tensor([[True, False, False, False],
                                     [True, True, False, False],
                                     [True, False, False, False]])
            out = paddle.masked_select(x, mask)
            #[1.0 5.0 6.0 9.0]
    """

    if in_dygraph_mode():
        return core.ops.masked_select(x, mask)

    helper = LayerHelper("masked_select", **locals())
    # Both checks name the real API path so dtype errors point at
    # masked_select (the x check previously said 'mask_select').
    check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'],
                             'paddle.tensor.search.masked_select')
    check_variable_and_dtype(mask, 'mask', ['bool'],
                             'paddle.tensor.search.masked_select')
    out = helper.create_variable_for_type_inference(dtype=x.dtype)
    # Note: the masked_select op names its output slot 'Y', not 'Out'.
    helper.append_op(
        type='masked_select', inputs={'X': x,
                                      'Mask': mask}, outputs={'Y': out})
    return out
W
wawltor 已提交
684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712


def topk(x, k, axis=None, largest=True, sorted=True, name=None):
    """
    This OP returns the values and indices of the ``k`` largest (or smallest)
    elements of ``x`` along an axis. For a 1-D input the whole tensor is
    searched; for higher-rank inputs the search runs along :attr:`axis`.

    Args:
        x(Tensor): An input N-D Tensor with type float32, float64, int32, int64.
        k(int|Tensor): The number of top elements to look for along the axis.
        axis(int, optional): Axis to compute indices along. The effective range
            is [-R, R), where R is x.ndim; a negative axis is treated as
            axis + R. Default is None, in which case no axis is passed to the
            operator and its default (documented as -1, the last axis) applies.
        largest(bool, optional): If True, the largest elements are returned
            (descending order); otherwise the smallest (ascending order).
            Default is True.
        sorted(bool, optional): Whether to return the elements in sorted
            order. Default is True. On GPU devices the result is always sorted.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        tuple(Tensor): ``(values, indices)``. ``values`` has the same data
        type as ``x``; ``indices`` is int64.

    Examples:

        .. code-block:: python

            import paddle

            tensor_1 = paddle.to_tensor([1, 4, 5, 7])
            value_1, indices_1 = paddle.topk(tensor_1, k=1)
            print(value_1)
            # [7]
            print(indices_1)
            # [3]
            tensor_2 = paddle.to_tensor([[1, 4, 5, 7], [2, 6, 2, 5]])
            value_2, indices_2 = paddle.topk(tensor_2, k=1)
            print(value_2)
            # [[7]
            #  [6]]
            print(indices_2)
            # [[3]
            #  [1]]
            value_3, indices_3 = paddle.topk(tensor_2, k=1, axis=-1)
            print(value_3)
            # [[7]
            #  [6]]
            print(indices_3)
            # [[3]
            #  [1]]
            value_4, indices_4 = paddle.topk(tensor_2, k=1, axis=0)
            print(value_4)
            # [[2 6 5 7]]
            print(indices_4)
            # [[1 1 0 0]]

    """
    if in_dygraph_mode():
        # A tensor-valued k is read back to a Python scalar in eager mode.
        if isinstance(k, Variable):
            k = k.numpy().item(0)
        op_attrs = ['k', int(k)]
        if axis is not None:
            op_attrs += ['axis', axis]
        op_attrs += ['largest', largest, 'sorted', sorted]
        out, indices = core.ops.top_k_v2(x, *op_attrs)
        return out, indices

    helper = LayerHelper("top_k_v2", **locals())
    inputs = {"X": [x]}
    attrs = {}
    if isinstance(k, Variable):
        # In static mode a tensor-valued k is fed as the op's 'K' input.
        inputs['K'] = [k]
    else:
        attrs['k'] = k
    attrs['largest'] = largest
    attrs['sorted'] = sorted
    if axis is not None:
        attrs['axis'] = axis

    values = helper.create_variable_for_type_inference(dtype=x.dtype)
    indices = helper.create_variable_for_type_inference(dtype="int64")

    helper.append_op(
        type="top_k_v2",
        inputs=inputs,
        outputs={"Out": [values],
                 "Indices": [indices]},
        attrs=attrs)
    # Indices carry no gradient.
    indices.stop_gradient = True
    return values, indices