#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO: define searching & indexing functions of a tensor

import numpy as np
18

Z
zhiboniu 已提交
19
import paddle
20 21 22
from paddle import _C_ops, _legacy_C_ops
from paddle.common_ops_import import VarDesc, Variable

23
from ..fluid.data_feeder import check_dtype, check_variable_and_dtype
24
from ..fluid.framework import _in_legacy_dygraph
25 26 27 28 29 30 31
from ..framework import (
    LayerHelper,
    _non_static_mode,
    convert_np_dtype_to_dtype_,
    core,
    in_dygraph_mode,
)

# from ..fluid.layers import has_inf  #DEFINE_ALIAS
# from ..fluid.layers import has_nan  #DEFINE_ALIAS

# Public names are exported explicitly elsewhere; nothing is re-exported via
# ``from ... import *`` from this module.
__all__ = []


def argsort(x, axis=-1, descending=False, name=None):
    """
    Sorts the input along the given axis and returns the index tensor that
    orders the values. Sorting is ascending by default; set
    :attr:`descending` to True to sort in descending order.

    Args:
        x (Tensor): An input N-D Tensor with type float32, float64, int16,
            int32, int64, uint8.
        axis (int, optional): Axis to compute indices along. The effective
            range is [-R, R), where R is Rank(x). When axis < 0, it works the
            same way as axis + R. Default is -1.
        descending (bool, optional): If True, sort in descending order,
            otherwise sort in ascending order. Default is False.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        Tensor: sorted indices (with the same shape as ``x`` and with data
        type int64).

    Examples:

        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[[5,8,9,5],
                                   [0,0,1,7],
                                   [6,9,2,4]],
                                  [[5,2,4,2],
                                   [4,7,7,9],
                                   [1,7,0,6]]],
                                dtype='float32')
            out1 = paddle.argsort(x, axis=-1)
            out2 = paddle.argsort(x, axis=0)
            out3 = paddle.argsort(x, axis=1)

            print(out1)
            #[[[0 3 1 2]
            #  [0 1 2 3]
            #  [2 3 0 1]]
            # [[1 3 2 0]
            #  [0 1 2 3]
            #  [2 0 3 1]]]

            print(out2)
            #[[[0 1 1 1]
            #  [0 0 0 0]
            #  [1 1 1 0]]
            # [[1 0 0 0]
            #  [1 1 1 1]
            #  [0 0 0 1]]]

            print(out3)
            #[[[1 1 1 2]
            #  [0 0 2 0]
            #  [2 2 0 1]]
            # [[2 0 2 0]
            #  [1 1 0 2]
            #  [0 2 1 1]]]
    """
    # The argsort kernels return (sorted_values, indices); only the indices
    # are exposed by this API.
    if in_dygraph_mode():
        _, indices = _C_ops.argsort(x, axis, descending)
        return indices

    if _in_legacy_dygraph():
        _, indices = _legacy_C_ops.argsort(
            x, 'axis', axis, 'descending', descending
        )
        return indices

    # Static graph: validate the input, then append an `argsort` op.
    check_variable_and_dtype(
        x,
        'x',
        ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'],
        'argsort',
    )

    helper = LayerHelper("argsort", **locals())
    sorted_values = helper.create_variable_for_type_inference(
        dtype=x.dtype, stop_gradient=True
    )
    sorted_indices = helper.create_variable_for_type_inference(
        VarDesc.VarType.INT64, stop_gradient=True
    )
    op_attrs = {'axis': axis, 'descending': descending}
    helper.append_op(
        type='argsort',
        inputs={'X': x},
        outputs={'Out': sorted_values, 'Indices': sorted_indices},
        attrs=op_attrs,
    )
    return sorted_indices


def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
    """
    Computes the indices of the max elements of the input tensor's
    element along the provided axis.

    Args:
        x (Tensor): An input N-D Tensor with type float32, float64, int16,
            int32, int64, uint8.
        axis (int|Tensor, optional): Axis to compute indices along. The
            effective range is [-R, R), where R is x.ndim. When axis < 0, it
            works the same way as axis + R. Default is None, in which case
            ``x`` is flattened and the index of the max value in the
            flattened tensor is returned.
        keepdim (bool, optional): Whether to keep the given axis in output.
            If it is True, the output has the same number of dimensions as
            ``x``, with size one along ``axis``. Otherwise the output has one
            fewer dimension than ``x`` since the axis is squeezed. Default is
            False.
        dtype (str|np.dtype, optional): Data type of the output tensor which
            can be int32, int64. The default value is ``int64``, and it will
            return the int64 indices.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        Tensor, return the tensor of int32 if set :attr:`dtype` is int32, otherwise return the tensor of int64.

    Raises:
        TypeError: If ``axis`` is not None, an int or a Tensor.
        ValueError: If ``dtype`` is None.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[5,8,9,5],
                                 [0,0,1,7],
                                 [6,9,2,4]])
            out1 = paddle.argmax(x)
            print(out1) # 2
            out2 = paddle.argmax(x, axis=0)
            print(out2)
            # [2, 2, 0, 1]
            out3 = paddle.argmax(x, axis=-1)
            print(out3)
            # [2, 3, 1]
            out4 = paddle.argmax(x, axis=0, keepdim=True)
            print(out4)
            # [[2, 2, 0, 1]]
    """
    if axis is not None and not isinstance(axis, (int, Variable)):
        raise TypeError(
            "The type of 'axis'  must be int or Tensor or None in argmax, but received %s."
            % (type(axis))
        )

    if dtype is None:
        raise ValueError(
            "the value of 'dtype' in argmax could not be None, but received None"
        )

    var_dtype = convert_np_dtype_to_dtype_(dtype)
    # axis=None is passed to the underlying op as flatten=True with axis=0.
    flatten = False
    if axis is None:
        flatten = True
        axis = 0

    if in_dygraph_mode():
        return _C_ops.argmax(x, axis, keepdim, flatten, var_dtype)
    if _in_legacy_dygraph():
        out = _legacy_C_ops.arg_max(
            x,
            'axis',
            axis,
            'dtype',
            var_dtype,
            'keepdims',
            keepdim,
            'flatten',
            flatten,
        )
        return out

    helper = LayerHelper("argmax", **locals())
    check_variable_and_dtype(
        x,
        'x',
        ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'],
        'paddle.argmax',
    )
    # Report this API's own name (was mistakenly 'argmin') if dtype is invalid.
    check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmax')
    attrs = {}
    out = helper.create_variable_for_type_inference(var_dtype)
    attrs['keepdims'] = keepdim
    attrs['axis'] = axis
    attrs['flatten'] = flatten
    attrs['dtype'] = var_dtype
    helper.append_op(
        type='arg_max', inputs={'X': x}, outputs={'Out': [out]}, attrs=attrs
    )
    # Indices are not differentiable.
    out.stop_gradient = True
    return out


def argmin(x, axis=None, keepdim=False, dtype="int64", name=None):
    """
    Computes the indices of the min elements of the input tensor's
    element along the provided axis.

    Args:
        x (Tensor): An input N-D Tensor with type float32, float64, int16,
            int32, int64, uint8.
        axis (int|Tensor, optional): Axis to compute indices along. The
            effective range is [-R, R), where R is x.ndim. When axis < 0, it
            works the same way as axis + R. Default is None, in which case
            ``x`` is flattened and the index of the min value in the
            flattened tensor is returned.
        keepdim (bool, optional): Whether to keep the given axis in output.
            If it is True, the output has the same number of dimensions as
            ``x``, with size one along ``axis``. Otherwise the output has one
            fewer dimension than ``x`` since the axis is squeezed. Default is
            False.
        dtype (str, optional): Data type of the output tensor which can
            be int32, int64. The default value is 'int64', and it will
            return the int64 indices.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        Tensor, return the tensor of `int32` if set :attr:`dtype` is `int32`, otherwise return the tensor of `int64`.

    Raises:
        TypeError: If ``axis`` is not None, an int or a Tensor.
        ValueError: If ``dtype`` is None.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[5,8,9,5],
                                  [0,0,1,7],
                                  [6,9,2,4]])
            out1 = paddle.argmin(x)
            print(out1) # 4
            out2 = paddle.argmin(x, axis=0)
            print(out2)
            # [1, 1, 1, 2]
            out3 = paddle.argmin(x, axis=-1)
            print(out3)
            # [0, 0, 2]
            out4 = paddle.argmin(x, axis=0, keepdim=True)
            print(out4)
            # [[1, 1, 1, 2]]
    """
    if not (axis is None or isinstance(axis, (int, Variable))):
        raise TypeError(
            "The type of 'axis'  must be int or Tensor or None in argmin, but received %s."
            % (type(axis))
        )

    if dtype is None:
        raise ValueError(
            "the value of 'dtype' in argmin could not be None, but received None"
        )

    var_dtype = convert_np_dtype_to_dtype_(dtype)
    # axis=None is handed to the op as flatten=True together with axis=0.
    flatten = axis is None
    if flatten:
        axis = 0

    if in_dygraph_mode():
        return _C_ops.argmin(x, axis, keepdim, flatten, var_dtype)

    if _in_legacy_dygraph():
        return _legacy_C_ops.arg_min(
            x,
            'axis',
            axis,
            'dtype',
            var_dtype,
            'keepdims',
            keepdim,
            'flatten',
            flatten,
        )

    # Static graph path.
    helper = LayerHelper("argmin", **locals())
    check_variable_and_dtype(
        x,
        'x',
        ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'],
        'paddle.argmin',
    )
    check_dtype(var_dtype, 'dtype', ['int32', 'int64'], 'argmin')
    result = helper.create_variable_for_type_inference(var_dtype)
    helper.append_op(
        type='arg_min',
        inputs={'X': x},
        outputs={'Out': [result]},
        attrs={
            'keepdims': keepdim,
            'axis': axis,
            'flatten': flatten,
            'dtype': var_dtype,
        },
    )
    result.stop_gradient = True
    return result


def index_select(x, index, axis=0, name=None):
    """
    Returns a new tensor which indexes the ``x`` tensor along dimension
    ``axis`` using the entries in ``index``. The returned tensor has the same
    number of dimensions as ``x``; its size along ``axis`` equals the length
    of ``index`` and every other dimension matches ``x``.

    Args:
        x (Tensor): The input Tensor to be operated. The data of ``x`` can be one of float32, float64, int32, int64.
        index (Tensor): The 1-D Tensor containing the indices to index. The data type of ``index`` must be int32 or int64.
        axis (int, optional): The dimension in which we index. Default: 0.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        Tensor: A Tensor with same data type as ``x``.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0],
                                  [5.0, 6.0, 7.0, 8.0],
                                  [9.0, 10.0, 11.0, 12.0]])
            index = paddle.to_tensor([0, 1, 1], dtype='int32')
            out_z1 = paddle.index_select(x=x, index=index)
            #[[1. 2. 3. 4.]
            # [5. 6. 7. 8.]
            # [5. 6. 7. 8.]]
            out_z2 = paddle.index_select(x=x, index=index, axis=1)
            #[[ 1.  2.  2.]
            # [ 5.  6.  6.]
            # [ 9. 10. 10.]]
    """
    if in_dygraph_mode():
        return _C_ops.index_select(x, index, axis)

    if _in_legacy_dygraph():
        # The legacy op names the axis attribute 'dim'.
        return _legacy_C_ops.index_select(x, index, 'dim', axis)

    helper = LayerHelper("index_select", **locals())
    api_name = 'paddle.tensor.search.index_select'
    check_variable_and_dtype(
        x, 'x', ['float32', 'float64', 'int32', 'int64'], api_name
    )
    check_variable_and_dtype(index, 'index', ['int32', 'int64'], api_name)

    selected = helper.create_variable_for_type_inference(x.dtype)
    helper.append_op(
        type='index_select',
        inputs={'X': x, 'Index': index},
        outputs={'Out': selected},
        attrs={'dim': axis},
    )
    return selected


def nonzero(x, as_tuple=False):
    """
    Return the indices of all non-zero elements of ``x``.

    Given an n-dimensional ``x`` with shape [x_1, x_2, ..., x_n] that has
    ``z`` non-zero elements:

    * If ``as_tuple`` is False, the result is a single Tensor of shape
      [z, n], where row ``i`` holds the coordinates of the i-th non-zero
      element.
    * If ``as_tuple`` is True, the result is a tuple of length ``n``; its
      k-th element is a Tensor of shape [z, 1] holding the k-th coordinate
      of every non-zero element.

    Args:
        x (Tensor): The input tensor variable.
        as_tuple (bool, optional): Return type, Tensor or tuple of Tensor.

    Returns:
        Tensor or tuple(Tensor). The data type is int64.

    Examples:

        .. code-block:: python

            import paddle

            x1 = paddle.to_tensor([[1.0, 0.0, 0.0],
                                   [0.0, 2.0, 0.0],
                                   [0.0, 0.0, 3.0]])
            x2 = paddle.to_tensor([0.0, 1.0, 0.0, 3.0])
            out_z1 = paddle.nonzero(x1)
            print(out_z1)
            #[[0 0]
            # [1 1]
            # [2 2]]
            out_z1_tuple = paddle.nonzero(x1, as_tuple=True)
            for out in out_z1_tuple:
                print(out)
            #[[0]
            # [1]
            # [2]]
            #[[0]
            # [1]
            # [2]]
            out_z2 = paddle.nonzero(x2)
            print(out_z2)
            #[[1]
            # [3]]
            out_z2_tuple = paddle.nonzero(x2, as_tuple=True)
            for out in out_z2_tuple:
                print(out)
            #[[1]
            # [3]]

    """
    rank = len(x.shape)

    # `indices` has one row per non-zero element and one column per dim.
    if in_dygraph_mode():
        indices = _C_ops.nonzero(x)
    elif paddle.in_dynamic_mode():
        indices = _legacy_C_ops.where_index(x)
    else:
        helper = LayerHelper("where_index", **locals())
        indices = helper.create_variable_for_type_inference(
            dtype=core.VarDesc.VarType.INT64
        )
        helper.append_op(
            type='where_index',
            inputs={'Condition': x},
            outputs={'Out': [indices]},
        )

    if not as_tuple:
        return indices
    if rank == 1:
        # A single column already is the per-dimension result.
        return (indices,)
    # Split the [z, rank] result into `rank` column slices of shape [z, 1].
    return tuple(
        paddle.slice(indices, axes=[1], starts=[dim], ends=[dim + 1])
        for dim in range(rank)
    )


def sort(x, axis=-1, descending=False, name=None):
    """
    Sorts the input along the given axis and returns the sorted tensor.
    Sorting is ascending by default; set :attr:`descending` to True to sort
    in descending order.

    Args:
        x (Tensor): An input N-D Tensor with type float32, float64, int16,
            int32, int64, uint8.
        axis (int, optional): Axis to sort along. The effective range
            is [-R, R), where R is Rank(x). When axis < 0, it works the same
            way as axis + R. Default is -1.
        descending (bool, optional): If True, sort in descending order,
            otherwise sort in ascending order. Default is False.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        Tensor: sorted tensor (with the same shape and data type as ``x``).

    Examples:

        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[[5,8,9,5],
                                   [0,0,1,7],
                                   [6,9,2,4]],
                                  [[5,2,4,2],
                                   [4,7,7,9],
                                   [1,7,0,6]]],
                                 dtype='float32')
            out1 = paddle.sort(x=x, axis=-1)
            out2 = paddle.sort(x=x, axis=0)
            out3 = paddle.sort(x=x, axis=1)
            print(out1)
            #[[[5. 5. 8. 9.]
            #  [0. 0. 1. 7.]
            #  [2. 4. 6. 9.]]
            # [[2. 2. 4. 5.]
            #  [4. 7. 7. 9.]
            #  [0. 1. 6. 7.]]]
            print(out2)
            #[[[5. 2. 4. 2.]
            #  [0. 0. 1. 7.]
            #  [1. 7. 0. 4.]]
            # [[5. 8. 9. 5.]
            #  [4. 7. 7. 9.]
            #  [6. 9. 2. 6.]]]
            print(out3)
            #[[[0. 0. 1. 4.]
            #  [5. 8. 2. 5.]
            #  [6. 9. 9. 7.]]
            # [[1. 2. 0. 2.]
            #  [4. 7. 4. 6.]
            #  [5. 7. 7. 9.]]]
    """
    # sort reuses the argsort kernel and keeps only the sorted values.
    if in_dygraph_mode():
        sorted_x, _ = _C_ops.argsort(x, axis, descending)
        return sorted_x

    if _in_legacy_dygraph():
        sorted_x, _ = _legacy_C_ops.argsort(
            x, 'axis', axis, 'descending', descending
        )
        return sorted_x

    helper = LayerHelper("sort", **locals())
    # Unlike argsort, the sorted values stay differentiable.
    sorted_x = helper.create_variable_for_type_inference(
        dtype=x.dtype, stop_gradient=False
    )
    unused_indices = helper.create_variable_for_type_inference(
        VarDesc.VarType.INT64, stop_gradient=True
    )
    helper.append_op(
        type='argsort',
        inputs={'X': x},
        outputs={'Out': sorted_x, 'Indices': unused_indices},
        attrs={'axis': axis, 'descending': descending},
    )
    return sorted_x


def mode(x, axis=-1, keepdim=False, name=None):
    """
    Find the values and indices of the modes along the optional axis.

    Args:
        x (Tensor): Tensor, an input N-D Tensor with type float32, float64, int32, int64.
        axis (int, optional): Axis to compute indices along. The effective range
            is [-R, R), where R is x.ndim. When axis < 0, it works the same way
            as axis + R. Default is -1.
        keepdim (bool, optional): Whether to keep the given axis in output.
            If it is True, the outputs have the same number of dimensions as
            ``x``, with size one along ``axis``. Otherwise they have one fewer
            dimension than ``x`` since the axis is squeezed. Default is False.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        tuple(Tensor), return the values and indices. The value data type is the same as the input `x`. The indices data type is int64.

    Examples:

        .. code-block:: python

           import paddle

           tensor = paddle.to_tensor([[[1,2,2],[2,3,3]],[[0,5,5],[9,9,0]]], dtype=paddle.float32)
           res = paddle.mode(tensor, 2)
           print(res)
           # (Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
           #   [[2., 3.],
           #    [5., 9.]]), Tensor(shape=[2, 2], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
           #   [[1, 1],
           #    [1, 0]]))

    """
    if in_dygraph_mode():
        return _C_ops.mode(x, axis, keepdim)

    if _in_legacy_dygraph():
        return _legacy_C_ops.mode(x, "axis", axis, "keepdim", keepdim)

    helper = LayerHelper("mode", **locals())
    mode_values = helper.create_variable_for_type_inference(dtype=x.dtype)
    mode_indices = helper.create_variable_for_type_inference(dtype="int64")
    helper.append_op(
        type="mode",
        inputs={"X": [x]},
        outputs={"Out": [mode_values], "Indices": [mode_indices]},
        attrs={'axis': axis, 'keepdim': keepdim},
    )
    # Only the values carry gradients.
    mode_indices.stop_gradient = True
    return mode_values, mode_indices


def where(condition, x=None, y=None, name=None):
    r"""
    Return a Tensor of elements selected from either :attr:`x` or :attr:`y` according to corresponding elements of :attr:`condition`. Concretely,

    .. math::

        out_i =
        \begin{cases}
        x_i, & \text{if}  \ condition_i \  \text{is} \ True \\
        y_i, & \text{if}  \ condition_i \  \text{is} \ False \\
        \end{cases}.

    Notes:
        ``numpy.where(condition)`` is identical to ``paddle.nonzero(condition, as_tuple=True)``, please refer to :ref:`api_tensor_search_nonzero`.

    Args:
        condition (Tensor): The condition to choose x or y. When True (nonzero), yield x, otherwise yield y.
        x (Tensor|scalar, optional): A Tensor or scalar to choose when the condition is True with data type of float32, float64, int32 or int64. Either both or neither of x and y should be given.
        y (Tensor|scalar, optional): A Tensor or scalar to choose when the condition is False with data type of float32, float64, int32 or int64. Either both or neither of x and y should be given.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        Tensor|tuple(Tensor): If ``x`` and ``y`` are given, a Tensor whose
        shape is the broadcast of the shapes of :attr:`condition`, :attr:`x`
        and :attr:`y`, with the data type of :attr:`x`. If neither is given,
        the result of ``paddle.nonzero(condition, as_tuple=True)``.

    Raises:
        ValueError: If exactly one of ``x`` and ``y`` is given.

    Examples:

        .. code-block:: python

            import paddle

            x = paddle.to_tensor([0.9383, 0.1983, 3.2, 1.2])
            y = paddle.to_tensor([1.0, 1.0, 1.0, 1.0])

            out = paddle.where(x>1, x, y)
            print(out)
            #out: [1.0, 1.0, 3.2, 1.2]

            out = paddle.where(x>1)
            print(out)
            #out: (Tensor(shape=[2, 1], dtype=int64, place=CPUPlace, stop_gradient=True,
            #            [[2],
            #             [3]]),)
    """
    # Promote Python/NumPy scalars to 1-element tensors using NumPy's dtype.
    if np.isscalar(x):
        x = paddle.full([1], x, np.array([x]).dtype.name)
    if np.isscalar(y):
        y = paddle.full([1], y, np.array([y]).dtype.name)

    if x is None and y is None:
        return nonzero(condition, as_tuple=True)
    if x is None or y is None:
        raise ValueError("either both or neither of x and y should be given")

    if not paddle.in_dynamic_mode():
        check_variable_and_dtype(condition, 'condition', ['bool'], 'where')
        check_variable_and_dtype(
            x, 'x', ['float32', 'float64', 'int32', 'int64'], 'where'
        )
        check_variable_and_dtype(
            y, 'y', ['float32', 'float64', 'int32', 'int64'], 'where'
        )

    cond_shape = list(condition.shape)
    lhs_shape = list(x.shape)
    rhs_shape = list(y.shape)

    if lhs_shape == rhs_shape and cond_shape == lhs_shape:
        cond_b, x_b, y_b = condition, x, y
    else:
        # Broadcast all three operands by adding a zero tensor whose shape is
        # the broadcast of every input's shape. Arithmetic runs in x.dtype,
        # and the condition is cast back to bool at the end.
        x_zeros = paddle.zeros_like(x)
        y_zeros = paddle.zeros_like(y)
        cond_zeros = paddle.zeros_like(condition)
        cond_zeros = paddle.cast(cond_zeros, x.dtype)
        cond_as_x = paddle.cast(condition, x.dtype)

        common_zeros = paddle.add(x_zeros, y_zeros)
        common_zeros = paddle.add(common_zeros, cond_zeros)
        x_b = paddle.add(x, common_zeros)
        y_b = paddle.add(y, common_zeros)
        cond_b = paddle.add(cond_as_x, common_zeros)
        cond_b = paddle.cast(cond_b, 'bool')

    if in_dygraph_mode():
        return _C_ops.where(cond_b, x_b, y_b)

    if _in_legacy_dygraph():
        return _legacy_C_ops.where(cond_b, x_b, y_b)

    helper = LayerHelper("where", **locals())
    selected = helper.create_variable_for_type_inference(dtype=x.dtype)
    helper.append_op(
        type='where',
        inputs={'Condition': cond_b, 'X': x_b, 'Y': y_b},
        outputs={'Out': [selected]},
    )
    return selected


def index_sample(x, index):
    """
    **IndexSample Layer**

    Gathers elements of ``x`` row by row: for every row ``i``, the output
    row ``i`` holds ``x[i, index[i, j]]`` for each column ``j`` of ``index``.

    .. code-block:: text


                Given:

                X = [[1, 2, 3, 4, 5],
                     [6, 7, 8, 9, 10]]

                Index = [[0, 1, 3],
                         [0, 2, 4]]

                Then:

                Out = [[1, 2, 4],
                       [6, 8, 10]]

    Args:
        x (Tensor): The source input tensor with 2-D shape. Supported data type is
            int32, int64, float16, float32, float64.
        index (Tensor): The index input tensor with 2-D shape, first dimension should be same with X.
            Data type is int32 or int64.

    Returns:
        output (Tensor): The output is a tensor with the same shape as index.

    Examples:

        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0],
                                  [5.0, 6.0, 7.0, 8.0],
                                  [9.0, 10.0, 11.0, 12.0]], dtype='float32')
            index = paddle.to_tensor([[0, 1, 2],
                                      [1, 2, 3],
                                      [0, 0, 0]], dtype='int32')
            target = paddle.to_tensor([[100, 200, 300, 400],
                                       [500, 600, 700, 800],
                                       [900, 1000, 1100, 1200]], dtype='int32')
            out_z1 = paddle.index_sample(x, index)
            print(out_z1)
            #[[1. 2. 3.]
            # [6. 7. 8.]
            # [9. 9. 9.]]

            # Use the index of the maximum value by topk op
            # get the value of the element of the corresponding index in other tensors
            top_value, top_index = paddle.topk(x, k=2)
            out_z2 = paddle.index_sample(target, top_index)
            print(top_value)
            #[[ 4.  3.]
            # [ 8.  7.]
            # [12. 11.]]

            print(top_index)
            #[[3 2]
            # [3 2]
            # [3 2]]

            print(out_z2)
            #[[ 400  300]
            # [ 800  700]
            # [1200 1100]]

    """
    if in_dygraph_mode():
        return _C_ops.index_sample(x, index)

    if _in_legacy_dygraph():
        return _legacy_C_ops.index_sample(x, index)

    helper = LayerHelper("index_sample", **locals())
    api_name = 'paddle.tensor.search.index_sample'
    check_variable_and_dtype(
        x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], api_name
    )
    check_variable_and_dtype(index, 'index', ['int32', 'int64'], api_name)

    sampled = helper.create_variable_for_type_inference(dtype=x.dtype)
    helper.append_op(
        type='index_sample',
        inputs={'X': x, 'Index': index},
        outputs={'Out': sampled},
    )
    return sampled


def masked_select(x, mask, name=None):
    """
    Returns a new 1-D tensor which indexes the input tensor according to the ``mask``
    which is a tensor with data type of bool.

    Args:
        x (Tensor): The input Tensor, the data type can be int32, int64, float32, float64.
        mask (Tensor): The Tensor containing the binary mask to index with, its data type is bool.
        name(str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        A 1-D Tensor which is the same data type as ``x``.

    Examples:

        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0],
                                  [5.0, 6.0, 7.0, 8.0],
                                  [9.0, 10.0, 11.0, 12.0]])
            mask = paddle.to_tensor([[True, False, False, False],
                                     [True, True, False, False],
                                     [True, False, False, False]])
            out = paddle.masked_select(x, mask)
            #[1.0 5.0 6.0 9.0]
    """

    # New eager (dygraph) mode: dispatch straight to the C++ kernel.
    if in_dygraph_mode():
        return _C_ops.masked_select(x, mask)

    # Legacy dygraph mode: old-style op entry point.
    if _in_legacy_dygraph():
        return _legacy_C_ops.masked_select(x, mask)

    # Static graph mode: validate dtypes and append a `masked_select` op.
    # NOTE: no extra locals may be bound before this line, because every
    # local is forwarded to LayerHelper through **locals().
    helper = LayerHelper("masked_select", **locals())
    check_variable_and_dtype(
        x,
        'x',
        ['float32', 'float64', 'int32', 'int64'],
        # Was 'paddle.tensor.search.mask_select', which names a non-existent
        # API in dtype error messages; now consistent with the mask check.
        'paddle.tensor.search.masked_select',
    )
    check_variable_and_dtype(
        mask, 'mask', ['bool'], 'paddle.tensor.search.masked_select'
    )
    out = helper.create_variable_for_type_inference(dtype=x.dtype)
    helper.append_op(
        type='masked_select', inputs={'X': x, 'Mask': mask}, outputs={'Y': out}
    )
    return out
W
wawltor 已提交
864 865 866 867


def topk(x, k, axis=None, largest=True, sorted=True, name=None):
    """
    Return the values and indices of the k largest (or smallest) elements along an axis.

    For a 1-D input the selection runs over the whole tensor; for an input of
    higher rank it runs along :attr:`axis`.

    Args:
        x(Tensor): An input N-D Tensor with type float32, float64, int32, int64.
        k(int, Tensor): How many elements to select along the axis.
        axis(int, optional): The axis to select along. The effective range
            is [-R, R), where R is x.ndim; a negative axis is interpreted as
            axis + R. If None (the default), the last axis (-1) is used.
        largest(bool, optional): If True, select the largest elements
            (descending order); otherwise select the smallest elements
            (ascending order). Default is True.
        sorted(bool, optional): Whether the returned elements are in sorted
            order. Default is True. On GPU devices the result is always sorted.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        tuple(Tensor): ``(values, indices)``. ``values`` has the same data type
        as the input ``x``; ``indices`` has data type int64.

    Examples:

        .. code-block:: python

            import paddle

            data_1 = paddle.to_tensor([1, 4, 5, 7])
            value_1, indices_1 = paddle.topk(data_1, k=1)
            print(value_1) # [7]
            print(indices_1) # [3]

            data_2 = paddle.to_tensor([[1, 4, 5, 7], [2, 6, 2, 5]])
            value_2, indices_2 = paddle.topk(data_2, k=1)
            print(value_2) # [[7], [6]]
            print(indices_2) # [[3], [1]]

            value_3, indices_3 = paddle.topk(data_2, k=1, axis=-1)
            print(value_3) # [[7], [6]]
            print(indices_3) # [[3], [1]]

            value_4, indices_4 = paddle.topk(data_2, k=1, axis=0)
            print(value_4) # [[2, 6, 5, 7]]
            print(indices_4) # [[1, 1, 0, 0]]
    """

    if in_dygraph_mode():
        # The eager kernel takes every attribute positionally, so an
        # unspecified axis is mapped onto -1 here.
        values, indices = _C_ops.topk(
            x, k, -1 if axis is None else axis, largest, sorted
        )
        return values, indices

    if _non_static_mode():
        # Legacy dygraph ops take attributes as flat name/value pairs; the
        # axis pair is only included when the caller supplied one.
        attr_pairs = ['k', int(k)]
        if axis is not None:
            attr_pairs += ['axis', axis]
        attr_pairs += ['largest', largest, 'sorted', sorted]
        values, indices = _legacy_C_ops.top_k_v2(x, *attr_pairs)
        return values, indices

    # Static graph mode. The helper must be created before any other local
    # is bound on this path: **locals() forwards every local to LayerHelper.
    helper = LayerHelper("top_k_v2", **locals())

    # A Tensor-valued k is wired in as an op input; a Python int becomes
    # an attribute.
    op_attrs = {}
    if isinstance(k, Variable):
        op_inputs = {"X": [x], "K": [k]}
    else:
        op_inputs = {"X": [x]}
        op_attrs['k'] = k
    op_attrs.update(largest=largest, sorted=sorted)
    if axis is not None:
        op_attrs['axis'] = axis

    values = helper.create_variable_for_type_inference(dtype=x.dtype)
    indices = helper.create_variable_for_type_inference(dtype="int64")
    helper.append_op(
        type="top_k_v2",
        inputs=op_inputs,
        outputs={"Out": [values], "Indices": [indices]},
        attrs=op_attrs,
    )
    # The index output is marked stop_gradient so nothing propagates through it.
    indices.stop_gradient = True
    return values, indices
Y
Yanxing Shi 已提交
962 963


964 965 966 967 968 969
def bucketize(x, sorted_sequence, out_int32=False, right=False, name=None):
    """
    For every element of `x`, find the index of the bucket it falls into within the 1-D tensor `sorted_sequence`.

    This is a thin wrapper over :func:`searchsorted` that requires the
    boundaries tensor to be one-dimensional.

    Args:
        x(Tensor): An input N-D tensor value with type int32, int64, float32, float64.
        sorted_sequence(Tensor): An input 1-D tensor with type int32, int64, float32, float64. Its values must increase monotonically.
        out_int32(bool, optional): Data type of the output tensor, int32 when True and int64 otherwise. Default is False.
        right(bool, optional): If False (the default), return the lower bound of each value's position in `sorted_sequence`;
                               if True, return the upper bound. If a value of `sorted_sequence` is nan or inf, the size of
                               the innermost dimension is returned.
        name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor with the same shape as `x`; its dtype is int32 if :attr:`out_int32` is True, otherwise int64.

    Raises:
        ValueError: If `sorted_sequence` is not one-dimensional.

    Examples:

        .. code-block:: python

            import paddle

            sorted_sequence = paddle.to_tensor([2, 4, 8, 16], dtype='int32')
            x = paddle.to_tensor([[0, 8, 4, 16], [-1, 2, 8, 4]], dtype='int32')
            out1 = paddle.bucketize(x, sorted_sequence)
            print(out1)
            # Tensor(shape=[2, 4], dtype=int64, place=CPUPlace, stop_gradient=True,
            #        [[0, 2, 1, 3],
            #         [0, 0, 2, 1]])
            out2 = paddle.bucketize(x, sorted_sequence, right=True)
            print(out2)
            # Tensor(shape=[2, 4], dtype=int64, place=CPUPlace, stop_gradient=True,
            #        [[0, 3, 2, 4],
            #         [0, 1, 3, 2]])
            out3 = x.bucketize(sorted_sequence)
            print(out3)
            # Tensor(shape=[2, 4], dtype=int64, place=CPUPlace, stop_gradient=True,
            #        [[0, 2, 1, 3],
            #         [0, 0, 2, 1]])
            out4 = x.bucketize(sorted_sequence, right=True)
            print(out4)
            # Tensor(shape=[2, 4], dtype=int64, place=CPUPlace, stop_gradient=True,
            #        [[0, 3, 2, 4],
            #         [0, 1, 3, 2]])
    """
    # Validate the boundaries tensor's dtype before checking its rank.
    check_variable_and_dtype(
        sorted_sequence,
        'SortedSequence',
        ['float32', 'float64', 'int32', 'int64'],
        'paddle.searchsorted',
    )
    # Only a 1-D boundaries tensor is accepted; delegate to searchsorted.
    if sorted_sequence.dim() == 1:
        return searchsorted(
            sorted_sequence, x, out_int32=out_int32, right=right, name=name
        )
    raise ValueError(
        f"sorted_sequence tensor must be 1 dimension, but got dim {sorted_sequence.dim()}"
    )


1022 1023 1024
def searchsorted(
    sorted_sequence, values, out_int32=False, right=False, name=None
):
    """
    For each element of `values`, find its insertion index within the innermost dimension of `sorted_sequence`.

    Args:
        sorted_sequence(Tensor): An input N-D or 1-D tensor with type int32, int64, float32, float64. Its values must increase monotonically along the innermost dimension.
        values(Tensor): An input N-D tensor value with type int32, int64, float32, float64.
        out_int32(bool, optional): Data type of the output tensor, int32 when True and int64 otherwise. Default is False.
        right(bool, optional): If False (the default), return the lower bound of each value's position in `sorted_sequence`;
                               if True, return the upper bound. If a value of `sorted_sequence` is nan or inf, the size of
                               the innermost dimension is returned.
        name(str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        Tensor with the same shape as `values`; its dtype is int32 if :attr:`out_int32` is True, otherwise int64.

    Examples:

        .. code-block:: python

            import paddle

            sorted_sequence = paddle.to_tensor([[1, 3, 5, 7, 9, 11],
                                                [2, 4, 6, 8, 10, 12]], dtype='int32')
            values = paddle.to_tensor([[3, 6, 9, 10], [3, 6, 9, 10]], dtype='int32')
            out1 = paddle.searchsorted(sorted_sequence, values)
            print(out1)
            # Tensor(shape=[2, 4], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
            #        [[1, 3, 4, 5],
            #         [1, 2, 4, 4]])
            out2 = paddle.searchsorted(sorted_sequence, values, right=True)
            print(out2)
            # Tensor(shape=[2, 4], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
            #        [[2, 3, 5, 5],
            #         [1, 3, 4, 5]])
            sorted_sequence_1d = paddle.to_tensor([1, 3, 5, 7, 9, 11, 13])
            out3 = paddle.searchsorted(sorted_sequence_1d, values)
            print(out3)
            # Tensor(shape=[2, 4], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
            #        [[1, 3, 4, 5],
            #         [1, 3, 4, 5]])
    """
    # Dynamic-graph fast paths: new eager kernel first, then the legacy op.
    if in_dygraph_mode():
        return _C_ops.searchsorted(sorted_sequence, values, out_int32, right)
    if _in_legacy_dygraph():
        return _legacy_C_ops.searchsorted(
            sorted_sequence, values, "out_int32", out_int32, "right", right
        )

    # Static graph: check both operands' dtypes, then emit the op.
    check_variable_and_dtype(
        sorted_sequence,
        'SortedSequence',
        ['float32', 'float64', 'int32', 'int64'],
        'paddle.searchsorted',
    )
    check_variable_and_dtype(
        values,
        'Values',
        ['float32', 'float64', 'int32', 'int64'],
        'paddle.searchsorted',
    )

    # No new locals may be bound before this call: **locals() forwards the
    # function's locals to LayerHelper.
    helper = LayerHelper('searchsorted', **locals())
    out = helper.create_variable_for_type_inference(
        dtype='int32' if out_int32 else 'int64'
    )
    helper.append_op(
        type='searchsorted',
        inputs={'SortedSequence': sorted_sequence, "Values": values},
        outputs={'Out': out},
        attrs={"out_int32": out_int32, "right": right},
    )
    return out
1098 1099 1100 1101


def kthvalue(x, k, axis=None, keepdim=False, name=None):
    """
    Find values and indices of the k-th smallest at the axis.

    Args:
        x(Tensor): A N-D Tensor with type float32, float64, int32, int64.
        k(int): The k for the k-th smallest number to look for along the axis.
        axis(int, optional): Axis to compute indices along. The effective range
            is [-R, R), where R is x.ndim. When axis < 0, it works the same way
            as axis + R. The default is None, in which case axis -1 is used.
        keepdim(bool, optional): Whether to keep the given axis in output. If it is True, the dimensions will be same as input x and with size one in the axis. Otherwise the output dimensions are one fewer than x since the axis is squeezed. Default is False.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        tuple(Tensor), return the values and indices. The value data type is the same as the input `x`. The indices data type is int64.

    Examples:

        .. code-block:: python

            import paddle

            x = paddle.randn((2,3,2))
            # Tensor(shape=[2, 3, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
            #       [[[ 0.22954939, -0.01296274],
            #         [ 1.17135799, -0.34493217],
            #         [-0.19550551, -0.17573971]],
            #
            #        [[ 0.15104349, -0.93965352],
            #         [ 0.14745511,  0.98209465],
            #         [ 0.10732264, -0.55859774]]])
            y = paddle.kthvalue(x, 2, 1)
            # (Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
            # [[ 0.22954939, -0.17573971],
            #  [ 0.14745511, -0.55859774]]), Tensor(shape=[2, 2], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
            #  [[0, 2],
            #  [1, 2]]))
    """
    if _non_static_mode():
        if _in_legacy_dygraph():
            # Legacy op: attributes are flat name/value pairs; axis is only
            # passed when the caller supplied one.
            if axis is not None:
                return _legacy_C_ops.kthvalue(
                    x, 'k', k, "axis", axis, "keepdim", keepdim
                )
            return _legacy_C_ops.kthvalue(x, 'k', k, "keepdim", keepdim)
        # Eager kernel takes axis positionally; None maps to -1.
        return _C_ops.kthvalue(x, k, -1 if axis is None else axis, keepdim)

    # Static graph mode. Create the helper before binding any other local:
    # **locals() forwards every local to LayerHelper.
    helper = LayerHelper("kthvalue", **locals())
    inputs = {"X": [x]}
    # keepdim must be forwarded here as well: previously it was omitted, so
    # keepdim=True was silently dropped in static-graph mode, unlike the
    # dynamic-graph paths above which always pass it to the same op.
    attrs = {'k': k, 'keepdim': keepdim}
    if axis is not None:
        attrs['axis'] = axis
    values = helper.create_variable_for_type_inference(dtype=x.dtype)
    indices = helper.create_variable_for_type_inference(dtype="int64")

    helper.append_op(
        type="kthvalue",
        inputs=inputs,
        outputs={"Out": [values], "Indices": [indices]},
        attrs=attrs,
    )
    # The index output is marked stop_gradient so nothing propagates through it.
    indices.stop_gradient = True
    return values, indices