#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from ..framework import LayerHelper, convert_np_dtype_to_dtype_
from ..fluid.data_feeder import check_dtype, check_variable_and_dtype
from ..framework import core, in_dygraph_mode, _non_static_mode
from ..fluid.framework import _in_legacy_dygraph
from paddle.common_ops_import import Variable
from paddle.common_ops_import import VarDesc
from paddle import _C_ops, _legacy_C_ops

# TODO: define searching & indexing functions of a tensor
# from ..fluid.layers import has_inf  #DEFINE_ALIAS
# from ..fluid.layers import has_nan  #DEFINE_ALIAS

__all__ = []


def argsort(x, axis=-1, descending=False, name=None):
    """
    Sorts the input along the given axis, and returns the corresponding index tensor for the sorted output values. The default sort order is ascending; to sort in descending order, set :attr:`descending` to True.

    Args:
        x(Tensor): An input N-D Tensor with type float32, float64, int16,
            int32, int64, uint8.
        axis(int, optional): Axis to compute indices along. The effective range
            is [-R, R), where R is Rank(x). When axis < 0, it works the same way
            as axis + R. Default is -1.
        descending(bool, optional): If set to True, the algorithm sorts in descending
            order; otherwise it sorts in ascending order. Default is False.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        Tensor: sorted indices (with the same shape as ``x``
        and with data type int64).

    Examples:

        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[[5,8,9,5],
                                   [0,0,1,7],
                                   [6,9,2,4]],
                                  [[5,2,4,2],
                                   [4,7,7,9],
                                   [1,7,0,6]]],
                                dtype='float32')
            out1 = paddle.argsort(x, axis=-1)
            out2 = paddle.argsort(x, axis=0)
            out3 = paddle.argsort(x, axis=1)

            print(out1)
            #[[[0 3 1 2]
            #  [0 1 2 3]
            #  [2 3 0 1]]
            # [[1 3 2 0]
            #  [0 1 2 3]
            #  [2 0 3 1]]]

            print(out2)
            #[[[0 1 1 1]
            #  [0 0 0 0]
            #  [1 1 1 0]]
            # [[1 0 0 0]
            #  [1 1 1 1]
            #  [0 0 0 1]]]

            print(out3)
            #[[[1 1 1 2]
            #  [0 0 2 0]
            #  [2 2 0 1]]
            # [[2 0 2 0]
            #  [1 1 0 2]
            #  [0 2 1 1]]]
    """
    if in_dygraph_mode():
        _, ids = _C_ops.argsort(x, axis, descending)
        return ids

    if _in_legacy_dygraph():
        _, ids = _legacy_C_ops.argsort(
            x, 'axis', axis, 'descending', descending
        )
        return ids
    check_variable_and_dtype(
        x,
        'x',
        ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'],
        'argsort',
    )

    helper = LayerHelper("argsort", **locals())
    out = helper.create_variable_for_type_inference(
        dtype=x.dtype, stop_gradient=True
    )
    ids = helper.create_variable_for_type_inference(
        VarDesc.VarType.INT64, stop_gradient=True
    )
    helper.append_op(
        type='argsort',
        inputs={'X': x},
        outputs={'Out': out, 'Indices': ids},
        attrs={'axis': axis, 'descending': descending},
    )
    return ids


def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
    """
    Computes the indices of the maximum elements of the input tensor
    along the provided axis.

    Args:
        x(Tensor): An input N-D Tensor with type float32, float64, int16,
            int32, int64, uint8.
        axis(int, optional): Axis to compute indices along. The effective range
            is [-R, R), where R is x.ndim. When axis < 0, it works the same way
            as axis + R. Default is None; in that case the input `x` is flattened and the index of the maximum value over the flattened tensor is returned.
        keepdim(bool, optional): Whether to keep the given axis in the output. If it is True, the dimensions will be the same as the input x, with size one in the axis. Otherwise the output has one fewer dimension than x since the axis is squeezed. Default is False.
        dtype(str|np.dtype, optional): Data type of the output tensor which can
                    be int32, int64. The default value is ``int64`` , and it will
                    return the int64 indices.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        Tensor, return the tensor of int32 if set :attr:`dtype` is int32, otherwise return the tensor of int64.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[5,8,9,5],
                                 [0,0,1,7],
                                 [6,9,2,4]])
            out1 = paddle.argmax(x)
            print(out1) # 2
            out2 = paddle.argmax(x, axis=0)
            print(out2)
            # [2, 2, 0, 1]
            out3 = paddle.argmax(x, axis=-1)
            print(out3)
            # [2, 3, 1]
            out4 = paddle.argmax(x, axis=0, keepdim=True)
            print(out4)
            # [[2, 2, 0, 1]]
    """
    if axis is not None and not isinstance(axis, (int, Variable)):
        raise TypeError(
            "The type of 'axis'  must be int or Tensor or None in argmax, but received %s."
            % (type(axis))
        )

    if dtype is None:
        raise ValueError(
            "the value of 'dtype' in argmax could not be None, but received None"
        )

    var_dtype = convert_np_dtype_to_dtype_(dtype)
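    # When no axis is given, flatten the input and compute the argmax over all
    # elements; the kernel then receives axis=0 together with the flatten flag.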
    flatten = False
    if axis is None:
        flatten = True
        axis = 0

    if in_dygraph_mode():
        return _C_ops.argmax(x, axis, keepdim, flatten, var_dtype)
    if _in_legacy_dygraph():
        out = _legacy_C_ops.arg_max(
            x,
            'axis',
            axis,
            'dtype',
            var_dtype,
            'keepdims',
            keepdim,
            'flatten',
            flatten,
        )
        return out

    helper = LayerHelper("argmax", **locals())
    check_variable_and_dtype(
        x,
        'x',
        ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'],
        'paddle.argmax',
    )
    check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmax')
    attrs = {}
    out = helper.create_variable_for_type_inference(var_dtype)
    attrs['keepdims'] = keepdim
    attrs['axis'] = axis
    attrs['flatten'] = flatten
    attrs['dtype'] = var_dtype
    helper.append_op(
        type='arg_max', inputs={'X': x}, outputs={'Out': [out]}, attrs=attrs
    )
    out.stop_gradient = True
    return out


def argmin(x, axis=None, keepdim=False, dtype="int64", name=None):
    """
    Computes the indices of the minimum elements of the input tensor
    along the provided axis.

    Args:
        x(Tensor): An input N-D Tensor with type float32, float64, int16,
            int32, int64, uint8.
        axis(int, optional): Axis to compute indices along. The effective range
            is [-R, R), where R is x.ndim. When axis < 0, it works the same way
            as axis + R. Default is None; in that case the input `x` is flattened and the index of the minimum value over the flattened tensor is returned.
        keepdim(bool, optional): Whether to keep the given axis in the output. If it is True, the dimensions will be the same as the input x, with size one in the axis. Otherwise the output has one fewer dimension than x since the axis is squeezed. Default is False.
        dtype(str, optional): Data type of the output tensor which can
                    be int32, int64. The default value is 'int64', and it will
                    return the int64 indices.
        name(str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        Tensor, return the tensor of `int32` if set :attr:`dtype` is `int32`, otherwise return the tensor of `int64`.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[5,8,9,5],
                                  [0,0,1,7],
                                  [6,9,2,4]])
            out1 = paddle.argmin(x)
            print(out1) # 4
            out2 = paddle.argmin(x, axis=0)
            print(out2)
            # [1, 1, 1, 2]
            out3 = paddle.argmin(x, axis=-1)
            print(out3)
            # [0, 0, 2]
            out4 = paddle.argmin(x, axis=0, keepdim=True)
            print(out4)
            # [[1, 1, 1, 2]]
    """
    if axis is not None and not isinstance(axis, (int, Variable)):
        raise TypeError(
            "The type of 'axis'  must be int or Tensor or None in argmin, but received %s."
            % (type(axis))
        )

    if dtype is None:
        raise ValueError(
            "the value of 'dtype' in argmin could not be None, but received None"
        )

    var_dtype = convert_np_dtype_to_dtype_(dtype)
    flatten = False
    if axis is None:
        flatten = True
        axis = 0

    if in_dygraph_mode():
        return _C_ops.argmin(x, axis, keepdim, flatten, var_dtype)
    if _in_legacy_dygraph():
        out = _legacy_C_ops.arg_min(
            x,
            'axis',
            axis,
            'dtype',
            var_dtype,
            'keepdims',
            keepdim,
            'flatten',
            flatten,
        )
        return out

    helper = LayerHelper("argmin", **locals())
    check_variable_and_dtype(
        x,
        'x',
        ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'],
        'paddle.argmin',
    )
    check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin')
    out = helper.create_variable_for_type_inference(var_dtype)
    attrs = {}
    attrs['keepdims'] = keepdim
    attrs['axis'] = axis
    attrs['flatten'] = flatten
    attrs['dtype'] = var_dtype
    helper.append_op(
        type='arg_min', inputs={'X': x}, outputs={'Out': [out]}, attrs=attrs
    )
    out.stop_gradient = True
    return out


def index_select(x, index, axis=0, name=None):
    """

    Returns a new tensor which indexes the ``x`` tensor along dimension ``axis`` using
    the entries in ``index`` which is a Tensor. The returned tensor has the same number
    of dimensions as the original ``x`` tensor. The ``axis``-th dimension has the same
    size as the length of ``index``; other dimensions have the same size as in the ``x`` tensor.

    Args:
        x (Tensor): The input Tensor to be operated. The data of ``x`` can be one of float32, float64, int32, int64.
        index (Tensor): The 1-D Tensor containing the indices to index. The data type of ``index`` must be int32 or int64.
        axis (int, optional): The dimension in which we index. Default: if None, the ``axis`` is 0.
        name(str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        Tensor: A Tensor with same data type as ``x``.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0],
                                  [5.0, 6.0, 7.0, 8.0],
                                  [9.0, 10.0, 11.0, 12.0]])
            index = paddle.to_tensor([0, 1, 1], dtype='int32')
            out_z1 = paddle.index_select(x=x, index=index)
            #[[1. 2. 3. 4.]
            # [5. 6. 7. 8.]
            # [5. 6. 7. 8.]]
            out_z2 = paddle.index_select(x=x, index=index, axis=1)
            #[[ 1.  2.  2.]
            # [ 5.  6.  6.]
            # [ 9. 10. 10.]]
    """

    if in_dygraph_mode():
        return _C_ops.index_select(x, index, axis)

    if _in_legacy_dygraph():
        return _legacy_C_ops.index_select(x, index, 'dim', axis)

    helper = LayerHelper("index_select", **locals())
    check_variable_and_dtype(
        x,
        'x',
        ['float32', 'float64', 'int32', 'int64'],
        'paddle.tensor.search.index_select',
    )
    check_variable_and_dtype(
        index, 'index', ['int32', 'int64'], 'paddle.tensor.search.index_select'
    )

    out = helper.create_variable_for_type_inference(x.dtype)

    helper.append_op(
        type='index_select',
        inputs={'X': x, 'Index': index},
        outputs={'Out': out},
        attrs={'dim': axis},
    )
    return out


def nonzero(x, as_tuple=False):
    """
    Return a tensor containing the indices of all non-zero elements of the `input`
    tensor. If as_tuple is True, return a tuple of 1-D tensors, one for each dimension
    in `input`, each containing the indices (in that dimension) of all non-zero elements
    of `input`. Given an n-dimensional `input` tensor with shape [x_1, x_2, ..., x_n], if
    as_tuple is False, we can get an output tensor with shape [z, n], where `z` is the
    number of all non-zero elements in the `input` tensor. If as_tuple is True, we can get
    a 1-D tensor tuple of length `n`, and the shape of each 1-D tensor is [z, 1].

    Args:
        x (Tensor): The input tensor variable.
        as_tuple (bool, optional): Whether to return a tuple of 1-D index tensors (one per dimension of `x`) instead of a single Tensor. Default is False.

    Returns:
        Tensor. The data type is int64.

    Examples:

        .. code-block:: python

            import paddle

            x1 = paddle.to_tensor([[1.0, 0.0, 0.0],
                                   [0.0, 2.0, 0.0],
                                   [0.0, 0.0, 3.0]])
            x2 = paddle.to_tensor([0.0, 1.0, 0.0, 3.0])
            out_z1 = paddle.nonzero(x1)
            print(out_z1)
            #[[0 0]
            # [1 1]
            # [2 2]]
            out_z1_tuple = paddle.nonzero(x1, as_tuple=True)
            for out in out_z1_tuple:
                print(out)
            #[[0]
            # [1]
            # [2]]
            #[[0]
            # [1]
            # [2]]
            out_z2 = paddle.nonzero(x2)
            print(out_z2)
            #[[1]
            # [3]]
            out_z2_tuple = paddle.nonzero(x2, as_tuple=True)
            for out in out_z2_tuple:
                print(out)
            #[[1]
            # [3]]

    """
    list_out = []
    shape = x.shape
    rank = len(shape)

    if in_dygraph_mode():
        outs = _C_ops.nonzero(x)
    elif paddle.in_dynamic_mode():
        outs = _legacy_C_ops.where_index(x)
    else:
        helper = LayerHelper("where_index", **locals())

        outs = helper.create_variable_for_type_inference(
            dtype=core.VarDesc.VarType.INT64
        )

        helper.append_op(
            type='where_index', inputs={'Condition': x}, outputs={'Out': [outs]}
        )

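    # With as_tuple=False the [z, n] index matrix is returned directly; otherwise
    # it is split into one index tensor per input dimension.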
    if not as_tuple:
        return outs
    elif rank == 1:
        return tuple([outs])
    else:
        for i in range(rank):
            list_out.append(
                paddle.slice(outs, axes=[1], starts=[i], ends=[i + 1])
            )
        return tuple(list_out)


def sort(x, axis=-1, descending=False, name=None):
    """

    Sorts the input along the given axis, and returns the sorted output tensor. The default sort order is ascending; to sort in descending order, set :attr:`descending` to True.

    Args:
        x(Tensor): An input N-D Tensor with type float32, float64, int16,
            int32, int64, uint8.
        axis(int, optional): Axis to compute indices along. The effective range
            is [-R, R), where R is Rank(x). When axis < 0, it works the same way
            as axis + R. Default is -1.
        descending(bool, optional): If set to True, the algorithm sorts in descending
            order; otherwise it sorts in ascending order. Default is False.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        Tensor: sorted tensor (with the same shape and data type as ``x``).
    Examples:

        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[[5,8,9,5],
                                   [0,0,1,7],
                                   [6,9,2,4]],
                                  [[5,2,4,2],
                                   [4,7,7,9],
                                   [1,7,0,6]]],
                                 dtype='float32')
            out1 = paddle.sort(x=x, axis=-1)
            out2 = paddle.sort(x=x, axis=0)
            out3 = paddle.sort(x=x, axis=1)
            print(out1)
            #[[[5. 5. 8. 9.]
            #  [0. 0. 1. 7.]
            #  [2. 4. 6. 9.]]
            # [[2. 2. 4. 5.]
            #  [4. 7. 7. 9.]
            #  [0. 1. 6. 7.]]]
            print(out2)
            #[[[5. 2. 4. 2.]
            #  [0. 0. 1. 7.]
            #  [1. 7. 0. 4.]]
            # [[5. 8. 9. 5.]
            #  [4. 7. 7. 9.]
            #  [6. 9. 2. 6.]]]
            print(out3)
            #[[[0. 0. 1. 4.]
            #  [5. 8. 2. 5.]
            #  [6. 9. 9. 7.]]
            # [[1. 2. 0. 2.]
            #  [4. 7. 4. 6.]
            #  [5. 7. 7. 9.]]]
    """
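    # sort reuses the argsort kernel; the indices output is discarded and only
    # the sorted values are returned.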
    if in_dygraph_mode():
        outs, _ = _C_ops.argsort(x, axis, descending)
        return outs

    if _in_legacy_dygraph():
        outs, _ = _legacy_C_ops.argsort(
            x, 'axis', axis, 'descending', descending
        )
        return outs
    helper = LayerHelper("sort", **locals())
    out = helper.create_variable_for_type_inference(
        dtype=x.dtype, stop_gradient=False
    )
    ids = helper.create_variable_for_type_inference(
        VarDesc.VarType.INT64, stop_gradient=True
    )
    helper.append_op(
        type='argsort',
        inputs={'X': x},
        outputs={'Out': out, 'Indices': ids},
        attrs={'axis': axis, 'descending': descending},
    )
    return out


def mode(x, axis=-1, keepdim=False, name=None):
    """
    Used to find the values and indices of the modes along the optional axis.

    Args:
        x(Tensor): Tensor, an input N-D Tensor with type float32, float64, int32, int64.
        axis(int, optional): Axis to compute indices along. The effective range
            is [-R, R), where R is x.ndim. When axis < 0, it works the same way
            as axis + R. Default is -1.
        keepdim(bool, optional): Whether to keep the given axis in the output. If it is True, the dimensions will be the same as the input x, with size one in the axis. Otherwise the output has one fewer dimension than x since the axis is squeezed. Default is False.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        tuple(Tensor), return the values and indices. The value data type is the same as the input `x`. The indices data type is int64.

    Examples:

        .. code-block:: python

           import paddle

           tensor = paddle.to_tensor([[[1,2,2],[2,3,3]],[[0,5,5],[9,9,0]]], dtype=paddle.float32)
           res = paddle.mode(tensor, 2)
           print(res)
           # (Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
           #   [[2., 3.],
           #    [5., 9.]]), Tensor(shape=[2, 2], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
           #   [[1, 1],
           #    [1, 0]]))

    """
    if in_dygraph_mode():
        return _C_ops.mode(x, axis, keepdim)
    if _in_legacy_dygraph():
        return _legacy_C_ops.mode(x, "axis", axis, "keepdim", keepdim)

    helper = LayerHelper("mode", **locals())
    inputs = {"X": [x]}
    attrs = {}
    attrs['axis'] = axis
    attrs['keepdim'] = keepdim

    values = helper.create_variable_for_type_inference(dtype=x.dtype)
    indices = helper.create_variable_for_type_inference(dtype="int64")

    helper.append_op(
        type="mode",
        inputs=inputs,
        outputs={"Out": [values], "Indices": [indices]},
        attrs=attrs,
    )
    indices.stop_gradient = True
    return values, indices


def where(condition, x=None, y=None, name=None):
    r"""
    Return a Tensor of elements selected from either :attr:`x` or :attr:`y` according to corresponding elements of :attr:`condition`. Concretely,

    .. math::

        out_i =
        \begin{cases}
        x_i, & \text{if}  \ condition_i \  \text{is} \ True \\
        y_i, & \text{if}  \ condition_i \  \text{is} \ False \\
        \end{cases}.

    Notes:
        ``numpy.where(condition)`` is identical to ``paddle.nonzero(condition, as_tuple=True)``, please refer to :ref:`api_tensor_search_nonzero`.

    Args:
        condition (Tensor): The condition to choose x or y. When True (nonzero), yield x, otherwise yield y.
        x (Tensor|scalar, optional): A Tensor or scalar to choose when the condition is True with data type of float32, float64, int32 or int64. Either both or neither of x and y should be given.
        y (Tensor|scalar, optional): A Tensor or scalar to choose when the condition is False with data type of float32, float64, int32 or int64. Either both or neither of x and y should be given.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        Tensor: A Tensor with the same shape as :attr:`condition` and same data type as :attr:`x` and :attr:`y`.

    Examples:

        .. code-block:: python

            import paddle

            x = paddle.to_tensor([0.9383, 0.1983, 3.2, 1.2])
            y = paddle.to_tensor([1.0, 1.0, 1.0, 1.0])

            out = paddle.where(x>1, x, y)
            print(out)
            #out: [1.0, 1.0, 3.2, 1.2]

            out = paddle.where(x>1)
            print(out)
            #out: (Tensor(shape=[2, 1], dtype=int64, place=CPUPlace, stop_gradient=True,
            #            [[2],
            #             [3]]),)
    """
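    # A scalar x or y is promoted to a 1-element tensor whose dtype is inferred via numpy.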
    if np.isscalar(x):
        x = paddle.full([1], x, np.array([x]).dtype.name)

    if np.isscalar(y):
        y = paddle.full([1], y, np.array([y]).dtype.name)

    if x is None and y is None:
        return nonzero(condition, as_tuple=True)

    if x is None or y is None:
        raise ValueError("either both or neither of x and y should be given")

    if not paddle.in_dynamic_mode():
        check_variable_and_dtype(condition, 'condition', ['bool'], 'where')
        check_variable_and_dtype(
            x, 'x', ['float32', 'float64', 'int32', 'int64'], 'where'
        )
        check_variable_and_dtype(
            y, 'y', ['float32', 'float64', 'int32', 'int64'], 'where'
        )

    condition_shape = list(condition.shape)
    x_shape = list(x.shape)
    y_shape = list(y.shape)

    if x_shape == y_shape and condition_shape == x_shape:
        broadcast_condition = condition
        broadcast_x = x
        broadcast_y = y
    else:
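        # Broadcast condition, x and y to a common shape by adding zero tensors of
        # each operand's shape, so the where op receives operands of identical shape.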
        zeros_like_x = paddle.zeros_like(x)
        zeros_like_y = paddle.zeros_like(y)
        zeros_like_condition = paddle.zeros_like(condition)
        zeros_like_condition = paddle.cast(zeros_like_condition, x.dtype)
        cast_cond = paddle.cast(condition, x.dtype)

        broadcast_zeros = paddle.add(zeros_like_x, zeros_like_y)
        broadcast_zeros = paddle.add(broadcast_zeros, zeros_like_condition)
        broadcast_x = paddle.add(x, broadcast_zeros)
        broadcast_y = paddle.add(y, broadcast_zeros)
        broadcast_condition = paddle.add(cast_cond, broadcast_zeros)
        broadcast_condition = paddle.cast(broadcast_condition, 'bool')

    if in_dygraph_mode():
        return _C_ops.where(broadcast_condition, broadcast_x, broadcast_y)
    else:
        if _in_legacy_dygraph():
            return _legacy_C_ops.where(
                broadcast_condition, broadcast_x, broadcast_y
            )
        else:
            helper = LayerHelper("where", **locals())
            out = helper.create_variable_for_type_inference(dtype=x.dtype)

            helper.append_op(
                type='where',
                inputs={
                    'Condition': broadcast_condition,
                    'X': broadcast_x,
                    'Y': broadcast_y,
                },
                outputs={'Out': [out]},
            )

            return out


def index_sample(x, index):
    """
    **IndexSample Layer**

    IndexSample OP returns the element of the specified location of X,
    and the location is specified by Index.

    .. code-block:: text


                Given:

                X = [[1, 2, 3, 4, 5],
                     [6, 7, 8, 9, 10]]

                Index = [[0, 1, 3],
                         [0, 2, 4]]

                Then:

                Out = [[1, 2, 4],
                       [6, 8, 10]]

    Args:
        x (Tensor): The source input tensor with 2-D shape. Supported data type is
            int32, int64, float32, float64.
        index (Tensor): The index input tensor with 2-D shape, first dimension should be same with X.
            Data type is int32 or int64.

    Returns:
        output (Tensor): The output is a tensor with the same shape as index.

    Examples:

        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0],
                                  [5.0, 6.0, 7.0, 8.0],
                                  [9.0, 10.0, 11.0, 12.0]], dtype='float32')
            index = paddle.to_tensor([[0, 1, 2],
                                      [1, 2, 3],
                                      [0, 0, 0]], dtype='int32')
            target = paddle.to_tensor([[100, 200, 300, 400],
                                       [500, 600, 700, 800],
                                       [900, 1000, 1100, 1200]], dtype='int32')
            out_z1 = paddle.index_sample(x, index)
            print(out_z1)
            #[[1. 2. 3.]
            # [6. 7. 8.]
            # [9. 9. 9.]]

            # Use the index of the maximum value by topk op
            # get the value of the element of the corresponding index in other tensors
            top_value, top_index = paddle.topk(x, k=2)
            out_z2 = paddle.index_sample(target, top_index)
            print(top_value)
            #[[ 4.  3.]
            # [ 8.  7.]
            # [12. 11.]]

            print(top_index)
            #[[3 2]
            # [3 2]
            # [3 2]]

            print(out_z2)
            #[[ 400  300]
            # [ 800  700]
            # [1200 1100]]

    """
    if in_dygraph_mode():
        return _C_ops.index_sample(x, index)
    else:
        if _in_legacy_dygraph():
            return _legacy_C_ops.index_sample(x, index)
        else:
            helper = LayerHelper("index_sample", **locals())
            check_variable_and_dtype(
                x,
                'x',
                ['float32', 'float64', 'int32', 'int64'],
                'paddle.tensor.search.index_sample',
            )
            check_variable_and_dtype(
                index,
                'index',
                ['int32', 'int64'],
                'paddle.tensor.search.index_sample',
            )
            out = helper.create_variable_for_type_inference(dtype=x.dtype)

            helper.append_op(
                type='index_sample',
                inputs={'X': x, 'Index': index},
                outputs={'Out': out},
            )
            return out


def masked_select(x, mask, name=None):
    """
    Returns a new 1-D tensor which indexes the input tensor according to the ``mask``
    which is a tensor with data type of bool.

    Args:
        x (Tensor): The input Tensor, the data type can be int32, int64, float32, float64.
        mask (Tensor): The Tensor containing the binary mask to index with, its data type is bool.
        name(str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        A 1-D Tensor which has the same data type as ``x``.

    Examples:

        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0],
                                  [5.0, 6.0, 7.0, 8.0],
                                  [9.0, 10.0, 11.0, 12.0]])
            mask = paddle.to_tensor([[True, False, False, False],
                                     [True, True, False, False],
                                     [True, False, False, False]])
            out = paddle.masked_select(x, mask)
            #[1.0 5.0 6.0 9.0]
    """

    if in_dygraph_mode():
        return _C_ops.masked_select(x, mask)

    if _in_legacy_dygraph():
        return _legacy_C_ops.masked_select(x, mask)

    helper = LayerHelper("masked_select", **locals())
    check_variable_and_dtype(
        x,
        'x',
        ['float32', 'float64', 'int32', 'int64'],
        'paddle.tensor.search.masked_select',
    )
    check_variable_and_dtype(
        mask, 'mask', ['bool'], 'paddle.tensor.search.masked_select'
    )
    out = helper.create_variable_for_type_inference(dtype=x.dtype)
    helper.append_op(
        type='masked_select', inputs={'X': x, 'Mask': mask}, outputs={'Y': out}
    )
    return out


def topk(x, k, axis=None, largest=True, sorted=True, name=None):
    """
    Return values and indices of the k largest or smallest at the optional axis.
    If the input is a 1-D Tensor, finds the k largest or smallest values and indices.
    If the input is a Tensor with higher rank, this operator computes the top k values and indices along the :attr:`axis`.

    Args:
        x(Tensor): Tensor, an input N-D Tensor with type float32, float64, int32, int64.
        k(int, Tensor): The number of top elements to look for along the axis.
        axis(int, optional): Axis to compute indices along. The effective range
            is [-R, R), where R is x.ndim. When axis < 0, it works the same way
            as axis + R. Default is -1.
        largest(bool, optional): If set to True, the algorithm sorts in descending
            order; otherwise it sorts in ascending order. Default is True.
        sorted(bool, optional): Controls whether to return the elements in sorted order. Default value is True. On GPU devices, it always returns the sorted values.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        tuple(Tensor), return the values and indices. The value data type is the same as the input `x`. The indices data type is int64.

    Examples:

        .. code-block:: python

            import paddle

            data_1 = paddle.to_tensor([1, 4, 5, 7])
            value_1, indices_1 = paddle.topk(data_1, k=1)
            print(value_1) # [7]
            print(indices_1) # [3]

            data_2 = paddle.to_tensor([[1, 4, 5, 7], [2, 6, 2, 5]])
            value_2, indices_2 = paddle.topk(data_2, k=1)
            print(value_2) # [[7], [6]]
            print(indices_2) # [[3], [1]]

            value_3, indices_3 = paddle.topk(data_2, k=1, axis=-1)
            print(value_3) # [[7], [6]]
            print(indices_3) # [[3], [1]]

            value_4, indices_4 = paddle.topk(data_2, k=1, axis=0)
            print(value_4) # [[2, 6, 5, 7]]
            print(indices_4) # [[1, 1, 0, 0]]


    """

    if in_dygraph_mode():
        if axis is None:
            axis = -1
        out, indices = _C_ops.topk(x, k, axis, largest, sorted)
        return out, indices

    if _non_static_mode():
        if axis is None:
            out, indices = _legacy_C_ops.top_k_v2(
                x, 'k', int(k), 'largest', largest, 'sorted', sorted
            )
        else:
            out, indices = _legacy_C_ops.top_k_v2(
                x,
                'k',
                int(k),
                'axis',
                axis,
                'largest',
                largest,
                'sorted',
                sorted,
            )
        return out, indices

    helper = LayerHelper("top_k_v2", **locals())
    inputs = {"X": [x]}
    attrs = {}
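    # A tensor-valued k is fed as the extra input 'K'; a plain int is passed as the 'k' attribute.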
    if isinstance(k, Variable):
        inputs['K'] = [k]
    else:
        attrs = {'k': k}
    attrs['largest'] = largest
    attrs['sorted'] = sorted
    if axis is not None:
        attrs['axis'] = axis

    values = helper.create_variable_for_type_inference(dtype=x.dtype)
    indices = helper.create_variable_for_type_inference(dtype="int64")

    helper.append_op(
        type="top_k_v2",
        inputs=inputs,
        outputs={"Out": [values], "Indices": [indices]},
        attrs=attrs,
    )
    indices.stop_gradient = True
    return values, indices


def bucketize(x, sorted_sequence, out_int32=False, right=False, name=None):
    """
    This API is used to find the indices in the innermost dimension of the 1-D tensor `sorted_sequence` corresponding to the given `x`.

    Args:
        x(Tensor): An input N-D tensor value with type int32, int64, float32, float64.
        sorted_sequence(Tensor): An input 1-D tensor with type int32, int64, float32, float64. The value of the tensor monotonically increases in the innermost dimension.
        out_int32(bool, optional): Data type of the output tensor which can be int32, int64. The default value is False, and it indicates that the output data type is int64.
        right(bool, optional): Find the upper or lower bounds of the sorted_sequence range in the innermost dimension based on the given `x`. If the value of the sorted_sequence is nan or inf, return the size of the innermost dimension.
                               The default value is False and it shows the lower bounds.
        name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor(the same sizes of the `x`), return the tensor of int32 if set :attr:`out_int32` is True, otherwise return the tensor of int64.

    Examples:

        .. code-block:: python

            import paddle

            sorted_sequence = paddle.to_tensor([2, 4, 8, 16], dtype='int32')
            x = paddle.to_tensor([[0, 8, 4, 16], [-1, 2, 8, 4]], dtype='int32')
            out1 = paddle.bucketize(x, sorted_sequence)
            print(out1)
            # Tensor(shape=[2, 4], dtype=int64, place=CPUPlace, stop_gradient=True,
            #        [[0, 2, 1, 3],
            #         [0, 0, 2, 1]])
            out2 = paddle.bucketize(x, sorted_sequence, right=True)
            print(out2)
            # Tensor(shape=[2, 4], dtype=int64, place=CPUPlace, stop_gradient=True,
            #        [[0, 3, 2, 4],
            #         [0, 1, 3, 2]])
            out3 = x.bucketize(sorted_sequence)
            print(out3)
            # Tensor(shape=[2, 4], dtype=int64, place=CPUPlace, stop_gradient=True,
            #        [[0, 2, 1, 3],
            #         [0, 0, 2, 1]])
            out4 = x.bucketize(sorted_sequence, right=True)
            print(out4)
            # Tensor(shape=[2, 4], dtype=int64, place=CPUPlace, stop_gradient=True,
            #        [[0, 3, 2, 4],
            #         [0, 1, 3, 2]])

    """
    check_variable_and_dtype(
        sorted_sequence,
        'SortedSequence',
        ['float32', 'float64', 'int32', 'int64'],
        'paddle.searchsorted',
    )
    if sorted_sequence.dim() != 1:
        raise ValueError(
            f"sorted_sequence tensor must be 1 dimension, but got dim {sorted_sequence.dim()}"
        )
    return searchsorted(sorted_sequence, x, out_int32, right, name)


def searchsorted(
    sorted_sequence, values, out_int32=False, right=False, name=None
):
    """
    Find the indices in the innermost dimension of `sorted_sequence` corresponding to the given `values`.

    Args:
        sorted_sequence(Tensor): An input N-D or 1-D tensor with type int32, int64, float32, float64. The value of the tensor monotonically increases in the innermost dimension.
        values(Tensor): An input N-D tensor value with type int32, int64, float32, float64.
        out_int32(bool, optional): Data type of the output tensor which can be int32, int64. The default value is False, and it indicates that the output data type is int64.
        right(bool, optional): Find the upper or lower bounds of the sorted_sequence range in the innermost dimension based on the given `values`. If the value of the sorted_sequence is nan or inf, return the size of the innermost dimension.
                               The default value is False and it shows the lower bounds.
        name(str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        Tensor(the same sizes of the `values`), return the tensor of int32 if set :attr:`out_int32` is True, otherwise return the tensor of int64.

    Examples:

        .. code-block:: python

            import paddle

            sorted_sequence = paddle.to_tensor([[1, 3, 5, 7, 9, 11],
                                                [2, 4, 6, 8, 10, 12]], dtype='int32')
            values = paddle.to_tensor([[3, 6, 9, 10], [3, 6, 9, 10]], dtype='int32')
            out1 = paddle.searchsorted(sorted_sequence, values)
            print(out1)
            # Tensor(shape=[2, 4], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
            #        [[1, 3, 4, 5],
            #         [1, 2, 4, 4]])
            out2 = paddle.searchsorted(sorted_sequence, values, right=True)
            print(out2)
            # Tensor(shape=[2, 4], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
            #        [[2, 3, 5, 5],
            #         [1, 3, 4, 5]])
            sorted_sequence_1d = paddle.to_tensor([1, 3, 5, 7, 9, 11, 13])
            out3 = paddle.searchsorted(sorted_sequence_1d, values)
            print(out3)
            # Tensor(shape=[2, 4], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
            #        [[1, 3, 4, 5],
            #         [1, 3, 4, 5]])

    """
    if in_dygraph_mode():
        return _C_ops.searchsorted(sorted_sequence, values, out_int32, right)

    if _in_legacy_dygraph():
        return _legacy_C_ops.searchsorted(
            sorted_sequence, values, "out_int32", out_int32, "right", right
        )

    check_variable_and_dtype(
        sorted_sequence,
        'SortedSequence',
        ['float32', 'float64', 'int32', 'int64'],
        'paddle.searchsorted',
    )
    check_variable_and_dtype(
        values,
        'Values',
        ['float32', 'float64', 'int32', 'int64'],
        'paddle.searchsorted',
    )

    helper = LayerHelper('searchsorted', **locals())
    out_type = 'int32' if out_int32 else 'int64'
    out = helper.create_variable_for_type_inference(dtype=out_type)
    helper.append_op(
        type='searchsorted',
        inputs={'SortedSequence': sorted_sequence, "Values": values},
        outputs={'Out': out},
        attrs={"out_int32": out_int32, "right": right},
    )

    return out


def kthvalue(x, k, axis=None, keepdim=False, name=None):
    """
    Find the values and indices of the k-th smallest elements along the axis.

    Args:
        x(Tensor): An N-D Tensor with type float32, float64, int32, int64.
        k(int): The k for the k-th smallest number to look for along the axis.
        axis(int, optional): Axis to compute indices along. The effective range
            is [-R, R), where R is x.ndim. When axis < 0, it works the same way
            as axis + R. The default is None; if the axis is None, it is computed as -1 by default.
        keepdim(bool, optional): Whether to keep the given axis in the output. If it is True, the dimensions will be the same as the input x, with size one in the axis. Otherwise the output has one fewer dimension than x since the axis is squeezed. Default is False.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        tuple(Tensor), return the values and indices. The value data type is the same as the input `x`. The indices data type is int64.

    Examples:

        .. code-block:: python

            import paddle

            x = paddle.randn((2,3,2))
            # Tensor(shape=[2, 3, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
            #       [[[ 0.22954939, -0.01296274],
            #         [ 1.17135799, -0.34493217],
            #         [-0.19550551, -0.17573971]],
            #
            #        [[ 0.15104349, -0.93965352],
            #         [ 0.14745511,  0.98209465],
            #         [ 0.10732264, -0.55859774]]])
            y = paddle.kthvalue(x, 2, 1)
            # (Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
            # [[ 0.22954939, -0.17573971],
            #  [ 0.14745511, -0.55859774]]), Tensor(shape=[2, 2], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
            #  [[0, 2],
            #  [1, 2]]))
    """
    if _non_static_mode():
        if axis is not None:
            if _in_legacy_dygraph():
                return _legacy_C_ops.kthvalue(
                    x, 'k', k, "axis", axis, "keepdim", keepdim
                )
            return _C_ops.kthvalue(x, k, axis, keepdim)
        else:
            if _in_legacy_dygraph():
                return _legacy_C_ops.kthvalue(x, 'k', k, "keepdim", keepdim)
            return _C_ops.kthvalue(x, k, -1, keepdim)

    helper = LayerHelper("kthvalue", **locals())
    inputs = {"X": [x]}
    attrs = {'k': k}
    if axis is not None:
        attrs['axis'] = axis
    values = helper.create_variable_for_type_inference(dtype=x.dtype)
    indices = helper.create_variable_for_type_inference(dtype="int64")

    helper.append_op(
        type="kthvalue",
        inputs=inputs,
        outputs={"Out": [values], "Indices": [indices]},
        attrs=attrs,
    )
    indices.stop_gradient = True
    return values, indices