#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import numpy as np
from . import unique_name
from . import core
import paddle
import warnings
from functools import reduce

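# Largest int32 value; used below as a sentinel for an unbounded slice start/end.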
MAX_INTEGER = 2**31 - 1


def is_list_tuple(index, contain_type):
    def _is_list_tuple(item):
        if not (isinstance(item, (list, tuple)) or type(item) == contain_type):
            return False
        if isinstance(item, (tuple, list)):
            for s in item:
                if not _is_list_tuple(s):
                    return False
        return True

    if not isinstance(index, (tuple, list)):
        return False
    for s in index:
        if not _is_list_tuple(s):
            return False
    return True
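
# Illustrative behavior (a sketch, not executed):
#   is_list_tuple([[1, 2], [3, 4]], int) -> True   (nested ints only)
#   is_list_tuple([1, 'a'], int)         -> False  ('a' is not an int)
#   is_list_tuple(1, int)                -> False  (not a list/tuple at the top level)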


def is_one_dim_list(index, contain_type):
    if isinstance(index, list):
        for i in index:
            if not isinstance(i, contain_type):
                return False
    else:
        return False
    return True
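
# e.g. is_one_dim_list([1, 2, 3], int) -> True; is_one_dim_list([[1], 2], int) -> False.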


def get_list_index_shape(var_dims, index_dims):
    var_dims_size = len(var_dims)
    index_dims_size = len(index_dims)

    out_dims_size = var_dims_size - index_dims[0] + index_dims_size - 1

    out_dims_shape = [1] * out_dims_size

    out_dims_shape[: index_dims_size - 1] = index_dims[1:]
    out_dims_shape[index_dims_size - 1 :] = var_dims[index_dims[0] :]
    return out_dims_shape
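
# Shape sketch: gathering from a tensor of shape [3, 4, 5, 6] with 2 stacked
# index tensors of shape [7] gives
#   get_list_index_shape([3, 4, 5, 6], [2, 7]) -> [7, 5, 6]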


class SliceInfo:
    def __init__(self):
        self.pre_shape = None
        self.indexes = []
        self.dtype = None

    def update(self, index):
        if is_list_tuple(index, int) or isinstance(
            index, (paddle.fluid.Variable, np.ndarray)
        ):
            # convert index to Tensor
            if not isinstance(index, paddle.fluid.Variable):
                index = paddle.assign(index)

            if self.dtype is None:
                self.dtype = index.dtype
            else:
                if index.dtype != self.dtype:
                    raise IndexError(
                        "Data type of Tensor/List index should be the same. The current data type is {}, but the previous data type is {}.".format(
                            index.dtype, self.dtype
                        )
                    )

            self.indexes.append(index)

            if self.pre_shape is None:
                self.pre_shape = index.shape
            else:
                if self.pre_shape != index.shape:
                    # broadcast
                    cur_shape = paddle.broadcast_shape(
                        self.pre_shape, index.shape
                    )
                    for i in range(len(self.indexes)):
                        self.indexes[i] = paddle.broadcast_to(
                            self.indexes[i], cur_shape
                        )
                self.pre_shape = self.indexes[-1].shape
        else:
            raise ValueError(
                "Index should be list/tuple of int or Tensor, but received {}.".format(
                    index
                )
            )

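    # Row-major strides, in elements, for a given shape; for example,
    # shape_stride([3, 4, 5]) -> [20, 5, 1].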
    def shape_stride(self, shape):
        s = [1] * len(shape)
        for i in range(len(shape) - 2, -1, -1):
            s[i] = shape[i + 1] * s[i + 1]

        return s

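    # Product of all dims, e.g. numel([2, 3, 4]) -> 24 (uses functools.reduce).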
    def numel(self, shape):
        return reduce(lambda x, y: x * y, shape, 1)

    def get_offset_stride(self, tensor_shape):
        for index in self.indexes:
            if not isinstance(index, paddle.fluid.Variable):
                raise ValueError(
                    "only support list/tensor index, but received {}.".format(
                        type(index)
                    )
                )

        if len(self.indexes) <= len(tensor_shape) or len(self.indexes) == 1:
            shape = paddle.stack(self.indexes)
            axes = list(range(1, len(self.pre_shape) + 1)) + [
                0,
            ]
        else:
            raise ValueError(
                "too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format(
                    len(tensor_shape), self.pre_shape[0]
                )
            )

        shape_transpose = paddle.transpose(shape, axes)
        return shape_transpose

    def get_item(self, tensor):
        shape_transpose = self.get_offset_stride(tensor.shape)
        index = paddle.assign(shape_transpose)
        return paddle.gather_nd(tensor, index)

    def set_item(self, tensor_origin, value):
        if not isinstance(value, paddle.fluid.Variable):
            value = paddle.assign(value)
        tensor_type = None

        if tensor_origin.dtype in [
            core.VarDesc.VarType.FP32,
            core.VarDesc.VarType.FP64,
        ]:
            tensor = tensor_origin
        else:
            tensor_type = tensor_origin.dtype
            tensor = tensor_origin.astype(core.VarDesc.VarType.FP32)

        if value.dtype != tensor.dtype:
            value = value.astype(tensor.dtype)

        shape_transpose = self.get_offset_stride(tensor_origin.shape)
        index = paddle.assign(shape_transpose)

        gather_tensor_shape = get_list_index_shape(
            tensor.shape,
            [
                len(self.indexes),
            ]
            + list(self.indexes[-1].shape),
        )

        value_dims_bd = [
            1,
        ] * len(gather_tensor_shape)
        value_dims_bd[-len(value.shape) :] = list(value.shape)

        for i in range(len(gather_tensor_shape)):
            if not (
                len(value_dims_bd) == 0
                or value_dims_bd[i] == gather_tensor_shape[i]
                or value_dims_bd[i] == 1
            ):
                raise ValueError(
                    "{} cannot be broadcast to {}".format(
                        value.shape, gather_tensor_shape
                    )
                )

        value_broadcast = paddle.broadcast_to(value, gather_tensor_shape)

        value_1d = value_broadcast.reshape(
            [-1] + gather_tensor_shape[len(index.shape) - 1 :]
        )

        index_1d = index.reshape([-1, index.shape[-1]])

        tensor_stride = paddle.assign(
            self.shape_stride(tensor.shape[: index.shape[-1]])
        )
        inds = []
        for i in range(index_1d.shape[0]):
            temp = (index_1d[i] * tensor_stride).sum()
            inds.append(temp)
        index_1d = paddle.stack(inds).reshape([-1])
        t_reshape = tensor.reshape([-1] + list(tensor.shape[index.shape[-1] :]))
        out = paddle.scatter(t_reshape, index_1d, value_1d)
        if tensor_type is not None:
            out = out.astype(tensor_type)
        tensor_origin = _setitem_impl_(
            tensor_origin, ..., out.reshape(tensor_origin.shape)
        )

        return tensor_origin


def replace_ellipsis(var, item):
    from .framework import Variable

    # Use slice(None) to replace Ellipsis.
    # For var, var.shape = [3,4,5,6]
    #
    #   var[..., 1:2] -> var[:, :, :, 1:2]
    #   var[0, ...] -> var[0]
    #   var[0, ..., 1:2] -> var[0, :, :, 1:2]

    item = list(item)

    # Remove Variables/ndarrays before counting the Ellipsis: Tensor overloads
    # `==`, which would break list.count below.
    item_remove_var = [
        ele
        for ele in item
        if not isinstance(ele, (Variable, np.ndarray)) and ele is not None
    ]
    ell_count = item_remove_var.count(Ellipsis)
    if ell_count == 0:
        return item
    elif ell_count > 1:
        raise IndexError("An index can only have a single ellipsis ('...')")

    ell_idx = item.index(Ellipsis)

    if ell_idx == len(item) - 1:
        return item[:-1]
    else:
        item[ell_idx : ell_idx + 1] = [slice(None)] * (
            len(var.shape) - len(item) + item.count(None) + 1
        )

    return item


def replace_ndarray(item):
    new_item = []
    for slice_item in item:
        if isinstance(slice_item, np.ndarray):
            new_item.append(paddle.assign(slice_item))
        else:
            new_item.append(slice_item)
    return new_item


def replace_none(item):
    new_item = []
    none_axes = []
    for i, slice_item in enumerate(item):
        if slice_item is None:
            none_axes.append(i)
        else:
            new_item.append(slice_item)
    return new_item, none_axes
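
# Sketch: replace_none((None, 0, slice(None), None)) -> ([0, slice(None)], [0, 3]);
# None entries are dropped and their original positions are recorded in none_axes.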


def is_integer_or_scalar_tensor(ele):
    from .framework import Variable

    if isinstance(ele, int):
        return True
    elif isinstance(ele, Variable):
        # NOTE(zoooo0820): For compatibility, if FLAGS_set_to_1d is set to True,
        # 1-D tensor is still treated as a scalar, which means basic indexing.
        # This will be removed in future.
        if paddle.get_flags('FLAGS_set_to_1d')['FLAGS_set_to_1d']:
            if len(ele.shape) == 1 and ele.shape[0] == 1:
                warnings.warn(
                    "1-D Tensor will be treat as advanced indexing in future version. Currently, 1-D Tensor means a scalar, not vector, and please modify it to 0-D Tensor. If advanced indexing is needed, please use `export FLAGS_set_to_1d=False` to set the flag."
                )
                return True
        if len(ele.shape) == 0 and ele.dtype != paddle.bool:
            return True
    return False


def is_bool_tensor(ele):
    from .framework import Variable

    if isinstance(ele, Variable) and ele.dtype == paddle.bool:
        return True
    return False


def deal_attrs(attrs, attr, attr_name, tensor_attr_name, inputs, infer_flags):
    from .framework import Variable

    if paddle.utils._contain_var(attr):
        inputs[tensor_attr_name] = paddle.utils._convert_to_tensor_list(
            attr, dtype="int64"
        )
        for i, dim in enumerate(attr):
            if isinstance(dim, Variable):
                attrs[attr_name].append(-1)
                infer_flags[i] = -1
            else:
                attrs[attr_name].append(dim)
    else:
        attrs[attr_name] = attr


# the item is a tensor of bool
def get_value_for_bool_tensor(var, item):
    if len(item.shape) > len(var.shape):
        raise IndexError(
            "The dims of bool index doesn't match indexed array, "
            "the dims of bool index are expected to be equal to or less "
            "than {}, but received {}.".format(len(var.shape), len(item.shape))
        )
    i = 0
    item_shape = item.shape
    while i < len(item.shape):
        dim_len = item_shape[i]
        if dim_len != -1 and var.shape[i] != -1 and dim_len != var.shape[i]:
            raise IndexError(
                "The dimension of bool index doesn't match indexed array along "
                "dimension {}, the target dimension is {}, but received {}.".format(
                    i, var.shape[i], dim_len
                )
            )
        i += 1
    empty_shape = [0] + list(var.shape[i:])

    def idx_not_empty(var, item):
        from ..tensor import gather_nd

        bool_2_idx = paddle.nonzero(item)
        return gather_nd(var, bool_2_idx)

    from paddle.static.nn import cond

    return cond(
        item.any(),
        lambda: idx_not_empty(var, item),
        lambda: paddle.empty(empty_shape, var.dtype),
    )
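
# Behavior sketch: for var of shape [3, 4] and a bool mask of shape [3], the
# rows where the mask is True are gathered; if the mask is all False, an empty
# tensor of shape [0, 4] is returned instead.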


def _setitem_for_tensor_array(var, item, value):
    """branches for tensor array setitem operation.
    An item can be:
    (1) int/Variable, which is a simple number/variable such as [1], [-2]
    (2) Slice, which is represented by bounds such as [2:-1]
    (3) Tuple, which includes the above two cases such as [2:-1, 1]
    If item is case (1), we perform paddle.tensor.array_write,
    in other cases, we raise a NotImplementedError.
    """

    from .framework import Variable

    assert (
        not paddle.in_dynamic_mode()
    ), "setitem for tensor_array must be called in static graph mode."
    if isinstance(item, (Variable, int)):
        from paddle.jit.dy2static.variable_trans_func import (
            to_static_variable,
        )
        from paddle import cast
        from paddle.tensor import array_write

        item = paddle.cast(to_static_variable(item), dtype='int64')
        value = to_static_variable(value)
        return array_write(x=value, i=item, array=var)
    else:
        raise NotImplementedError(
            "Only support __setitem__ by Int/Variable in tensor_array, but received {}".format(
                type(item)
            )
        )


def _setitem_impl_(var, item, value):
    from paddle.fluid import core
    from .framework import default_main_program, Variable

    if var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
        return _setitem_for_tensor_array(var, item, value)

    inputs = {'Input': var}
    if isinstance(item, list):
        if not is_one_dim_list(item, int):
            item = tuple(item)
    # 1. Parse item
    if not isinstance(item, tuple):
        item = (item,)

    decrease_axes = []
    axes = []
    starts = []
    ends = []
    steps = []

    item = replace_ndarray(item)
    item = replace_ellipsis(var, item)
    item, none_axes = replace_none(item)
    slice_info = SliceInfo()
    dim = 0
    for _, slice_item in enumerate(item):
        if is_integer_or_scalar_tensor(slice_item) and not is_bool_tensor(
            slice_item
        ):
            decrease_axes.append(dim)
            start = slice_item
            end = slice_item + 1 if slice_item != -1 else MAX_INTEGER
            step = 1

        elif isinstance(slice_item, slice):
            start = slice_item.start
            end = slice_item.stop
            step = slice_item.step

            if start is None and end is None and step is None:
                dim += 1
                continue

            step = 1 if step is None else step

            if not isinstance(step, Variable) and step == 0:
                raise ValueError(
                    "When assigning a value to a paddle.Tensor, step can not be 0, "
                    "but received step is {}.".format(step)
                )

            if isinstance(step, Variable) and (start is None or end is None):
                raise ValueError(
                    "When assigning a value to a paddle.Tensor, it's not supported that "
                    "the start or end is None when the type of step is paddle.Tensor."
                )

            if start is None:
                start = 0 if step > 0 else MAX_INTEGER

            if end is None:
                end = MAX_INTEGER if step > 0 else (0 - MAX_INTEGER)
        elif isinstance(slice_item, list):
            if is_list_tuple(slice_item, int):
                slice_info.update(slice_item)
                continue

            for i in slice_item:
                if not isinstance(i, bool):
                    raise TypeError(
                        "Doesn't support {} in index list.".format(type(i))
                    )

            if len(item) != 1:
                raise IndexError(
                    "When index contains a bool list, its length must be 1, but received {}.".format(
                        len(item)
                    )
                )

            idx_tensor = paddle.assign(slice_item)
            return set_value_for_bool_tensor(var, idx_tensor, value)

        elif isinstance(slice_item, Variable):
            if slice_item.dtype == core.VarDesc.VarType.BOOL:
                if len(item) != 1:
                    raise IndexError(
                        "When index contains a bool tensor, its length must be 1, but received {}.".format(
                            len(item)
                        )
                    )
                return set_value_for_bool_tensor(var, slice_item, value)
            else:
                slice_info.update(slice_item)
                continue
        else:
            raise IndexError(
                "Valid index accepts int, slice, ellipsis, None, list of bool, Variable, "
                "but received {}.".format(slice_item)
            )

        axes.append(dim)
        starts.append(start)
        ends.append(end)
        steps.append(step)

        dim += 1
    if slice_info.indexes:
        if len(slice_info.indexes) != len(item):
            raise IndexError(
                "Valid index accepts int or slice or ellipsis or list, but received {}.".format(
                    item
                )
            )
        return slice_info.set_item(var, value)
    attrs = {
        'axes': axes,
        'starts': starts,
        'ends': ends,
        'steps': steps,
        'decrease_axes': decrease_axes,
        'none_axes': none_axes,
    }

    if paddle.utils._contain_var(starts):
        inputs['StartsTensorList'] = paddle.utils._convert_to_tensor_list(
            starts
        )
        del attrs['starts']
    if paddle.utils._contain_var(ends):
        inputs['EndsTensorList'] = paddle.utils._convert_to_tensor_list(ends)
        del attrs['ends']
    if paddle.utils._contain_var(steps):
        inputs['StepsTensorList'] = paddle.utils._convert_to_tensor_list(steps)
        del attrs['steps']

    # 2. Parse value
    dtype = var.dtype
    attrs['dtype'] = dtype

    from .data_feeder import convert_dtype

    #  2.1 value is an integer, float or complex
    if isinstance(value, (bool, int, float, complex)):
        value = np.array([value]).astype(convert_dtype(dtype))

    #  2.2 value is a np.ndarray
    if isinstance(value, np.ndarray):
        shape = list(value.shape)
        values = value.ravel().tolist()
        attrs["values"] = values
        attrs["shape"] = shape

    elif isinstance(value, (Variable, core.eager.Tensor)):
        inputs["ValueTensor"] = value
    else:
        raise TypeError(
            "Only support to assign an integer, float, numpy.ndarray or "
            "paddle.Tensor to a paddle.Tensor, but received {}".format(
                type(value)
            )
        )

    if paddle.in_dynamic_mode():
        var._bump_inplace_version()
        output = var
    else:
        helper = paddle.fluid.layer_helper.LayerHelper('set_value', **locals())
        output = helper.create_variable_for_type_inference(dtype=var.dtype)

    cur_block = default_main_program().current_block()
    cur_block.append_op(
        type="set_value",
        inputs=inputs,
        outputs={'Out': output},
        attrs=attrs,
        inplace_map={"Input": "Out"},
    )

    if not paddle.in_dynamic_mode():
        # map var to the new output
        from paddle.jit.dy2static.program_translator import (
            ProgramTranslator,
        )

        ProgramTranslator.get_instance()._params_map.add(
            cur_block.program, var.desc.id(), output
        )

    return output
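

# Worked example (a sketch, assuming the basic-indexing path): calling
# _setitem_impl_(x, (0, slice(1, 3)), 2.0) appends a `set_value` op with
# axes=[0, 1], starts=[0, 1], ends=[1, 3], steps=[1, 1], decrease_axes=[0].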


# the item is a tensor of bool
def set_value_for_bool_tensor(var, item, value):
    if len(item.shape) > len(var.shape):
        raise IndexError(
            "The dims of bool index doesn't match indexed array, "
            "the dims of bool index are expected to be equal to or less "
            "than {}, but received {}.".format(len(var.shape), len(item.shape))
        )
    for i, dim_len in enumerate(item.shape):
        if dim_len != -1 and var.shape[i] != -1 and dim_len != var.shape[i]:
            raise IndexError(
                "The dimension of bool index doesn't match indexed array along "
                "dimension {}, the target dimension is {}, but received {}.".format(
                    i, var.shape[i], dim_len
                )
            )

    def idx_not_empty(var, item, value):
        from .framework import Variable
        from ..tensor import gather_nd, scatter_nd_add

        if not isinstance(value, Variable):
            value = paddle.assign(value).cast(var.dtype)

        idx = paddle.nonzero(item)
        gather_val = gather_nd(var, idx)
        gather_val_new = value - gather_val
        out = scatter_nd_add(var, idx, gather_val_new)
        var = _setitem_impl_(var, ..., out)
        return var

    def idx_is_empty(var):
        return var

    from paddle.static.nn import cond

    # If all values in the bool index are False, do nothing.
    var = cond(
        item.any(),
        lambda: idx_not_empty(var, item, value),
        lambda: idx_is_empty(var),
    )

    return var


def deal_advanced_index(ori_tensor, indices, is_for_setitem):
    """
    Transpose origin Tensor and advanced indices to the front.

    Returns:
        transed_tensor (Tensor): transposed tensor, corresponding to the advanced indices
        transed_index (List): advanced index tensors, moved to the front
        trans_back_dim (List): order of axes to transpose back to original order. Only used in __setitem__.
        pos_of_new_dim (int):  axis of new dim in the result. Only used in __getitem__.
        rank_of_new_dim (int): rank of new dim in the result. Only used in __getitem__.
    """
    transed_dim = []
    transed_index = []

    # These flags indicate whether the result obtained by gather_nd requires a second transpose.
    # Only used in __getitem__.
    pos_of_new_dim = MAX_INTEGER
    rank_of_new_dim = 1

    for i, indice in enumerate(indices):
        if indice is not None:
            if not is_for_setitem:
                if i == 0:
                    # case 1: advanced indices at axis 0, the new dim will come first.
                    pos_of_new_dim = 0
                if i > 0 and len(transed_dim) > 0 and transed_dim[-1] != i - 1:
                    # case 2: the advanced indices are not adjacent, so the new dim will come first.
                    pos_of_new_dim = 0
                else:
                    pos_of_new_dim = min(pos_of_new_dim, i)
                rank_of_new_dim = max(rank_of_new_dim, indice[1].ndim)
            transed_dim.append(i)
            transed_index.append(indice[1])
    for i in range(ori_tensor.ndim):
        if indices[i] is None:
            transed_dim.append(i)
    transed_tensor = ori_tensor.transpose(transed_dim)

    trans_back_dim = np.argsort(transed_dim).tolist() if is_for_setitem else []

    return (
        transed_tensor,
        transed_index,
        trans_back_dim,
        pos_of_new_dim,
        rank_of_new_dim,
    )
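

# Transpose sketch (hypothetical shapes): for ori_tensor of shape [4, 5, 6] and
# indices = [None, (1, idx), None] with idx of shape [3], transed_dim is
# [1, 0, 2], so transed_tensor has shape [5, 4, 6] and transed_index == [idx].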


def parse_index(x, indices):
    advanced_index = [None] * 2 * len(x.shape)  # content is (dim, index)
    # for set_value / slice / strided_slice OP
    decrease_axes = []
    axes = []
    starts = []
    ends = []
    steps = []
    use_strided_slice = False
    has_advanced_index = False

    if isinstance(indices, list) and not is_one_dim_list(indices, int):
        indices = tuple(indices)

    if not isinstance(indices, tuple):
        indices = (indices,)

    indices = replace_ndarray(indices)
    indices = replace_ellipsis(x, indices)
    indices, none_axes = replace_none(indices)

    is_tensor_array = (
        hasattr(x, "desc")
        and x.desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY
    )

    estimated_dim = 0
    for dim, slice_item in enumerate(indices):
        start, end, step = None, None, None
        if is_integer_or_scalar_tensor(slice_item):
            if (
                not is_tensor_array
                and isinstance(slice_item, int)
                and x.shape[dim] is not None
                and x.shape[dim] >= 0
                and slice_item >= x.shape[dim]
            ):
                # For python, if users write a, b = var, the __getitem__
                # method will iterate through 0, 1, 2 ... until __getitem__
                # throws an IndexError, then stop. The var[0], var[1] will
                # be given to a, b respectively. If more values are given,
                # the unpack size would cause error.
                # We raise IndexError here to support grammar like `a, b = var`
                raise IndexError(
                    "slice_item %d at dim %d should be >= 0 and < x.shape[%d]: %d"
                    % (slice_item, dim, dim, x.shape[dim])
                )
            # not calculate result to reduce call times for slice OP.
            decrease_axes.append(dim)
            start = slice_item
            step = 1
            end = slice_item + 1 if slice_item != -1 else MAX_INTEGER
        elif isinstance(slice_item, bool):
            # single bool is advanced-indexing
            none_axes.append(dim)
            estimated_dim += 1
            advanced_index[estimated_dim] = (
                estimated_dim,
                paddle.to_tensor(slice_item),
            )
            has_advanced_index = True
        elif isinstance(slice_item, slice):
            start = slice_item.start
            end = slice_item.stop
            step = slice_item.step
            estimated_dim += 1

            if start is None and end is None and step is None:
                continue

            step = 1 if step is None else step
            if start is None:
                start = 0 if step > 0 else MAX_INTEGER
            if end is None:
                end = MAX_INTEGER if step > 0 else -1

        elif isinstance(slice_item, (list, tuple)):
            advanced_index[estimated_dim] = (
                estimated_dim,
                paddle.to_tensor(slice_item),
            )

            if (
                advanced_index[estimated_dim][1].dtype == paddle.bool
                and len(slice_item) != x.shape[dim]
            ):
                raise IndexError(
                    "The shape of boolean index {} did not match indexed tensor {} along axis {}".format(
                        len(slice_item), x.shape[dim], dim
                    )
                )

            has_advanced_index = True
            estimated_dim += 1

        elif isinstance(slice_item, paddle.fluid.Variable):
            # In this case, the Variable is not 0-dim Tensor and will be treated as advanced-indexing.
            if slice_item.dtype == paddle.bool:
                if slice_item.ndim == 0:
                    # 0-D bool Tensor, same as single PY-bool.
                    none_axes.append(dim)

                elif slice_item.shape[0] != x.shape[dim]:
                    raise IndexError(
                        "The shape of boolean index {} did not match indexed tensor {} along axis {}".format(
                            slice_item.shape[0], x.shape[dim], dim
                        )
                    )
            advanced_index[estimated_dim] = (estimated_dim, slice_item)
            has_advanced_index = True
            estimated_dim += 1

        else:
            raise IndexError(
                "Valid index accept int / bool / slice / ellipsis / list / Tuple / Ndarray / Tensor, but received {}.".format(
                    slice_item
                )
            )
        if not (start is None or end is None or step is None):
            starts.append(start)
            ends.append(end)
            steps.append(step)
            axes.append(dim)
            use_strided_slice = (
                True
                if (isinstance(step, paddle.fluid.Variable) or step != 1)
                else use_strided_slice
            )
    return (
        starts,
        ends,
        steps,
        axes,
        none_axes,
        decrease_axes,
        advanced_index,
        has_advanced_index,
        use_strided_slice,
    )
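

# Parsing sketch: parse_index(x, (0, slice(1, 3), None)) gives starts=[0, 1],
# ends=[1, 3], steps=[1, 1], axes=[0, 1], none_axes=[2], decrease_axes=[0],
# no advanced index, and use_strided_slice=False (all steps are 1).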


def _setitem_static(x, indices, values):
    """
    In dynamic mode, this function modifies the values of the input tensor in place and returns it.
    In static mode, it returns a new Tensor with the assigned values.

    Args:
        x(Tensor): Tensor whose values are to be set.
        indices(int|slice|None|Tensor|List|Tuple...): Indices, used to indicate the position of the elements to be assigned.
        values(Tensor|Number|Ndarray): values to be assigned to x.
    """
    from .framework import default_main_program, Variable

    if x.type == paddle.fluid.core.VarDesc.VarType.LOD_TENSOR_ARRAY:
        return _setitem_for_tensor_array(x, indices, values)

    # step1: parse the index and record its components
    (
        starts,
        ends,
        steps,
        axes,
        none_axes,
        decrease_axes,
        advanced_index,
        has_advanced_index,
        use_strided_slice,
    ) = parse_index(x, indices)

    inputs = {'Input': x}
    attrs = {
        'axes': axes,
        'starts': starts,
        'ends': ends,
        'steps': steps,
        'decrease_axes': decrease_axes,
        'none_axes': none_axes,
    }
    if paddle.utils._contain_var(starts):
        inputs['StartsTensorList'] = paddle.utils._convert_to_tensor_list(
            starts
        )
        del attrs['starts']
    if paddle.utils._contain_var(ends):
        inputs['EndsTensorList'] = paddle.utils._convert_to_tensor_list(ends)
        del attrs['ends']
    if paddle.utils._contain_var(steps):
        inputs['StepsTensorList'] = paddle.utils._convert_to_tensor_list(steps)
        del attrs['steps']

    if not has_advanced_index:
        # step2. Parse values
        dtype = x.dtype
        attrs['dtype'] = dtype

        from .data_feeder import convert_dtype

        if isinstance(values, (bool, int, float, complex)):
            values = np.array([values]).astype(convert_dtype(dtype))

        if isinstance(values, np.ndarray):
            shape = list(values.shape)
            values = values.ravel().tolist()
            attrs["values"] = values
            attrs["shape"] = shape

        elif isinstance(values, Variable):
            inputs["ValueTensor"] = values
        else:
            raise TypeError(
                "Only support to assign an integer, float, numpy.ndarray or "
                "paddle.Tensor to a paddle.Tensor, but received {}".format(
                    type(values)
                )
            )

        # step3.1: Only basic indexing, use OP set_value to set value.
        if paddle.in_dynamic_mode():
            x._bump_inplace_version()
            output = x
        else:
            helper = paddle.fluid.layer_helper.LayerHelper(
                'set_value', **locals()
            )
            output = helper.create_variable_for_type_inference(dtype=x.dtype)
        cur_block = default_main_program().current_block()
        cur_block.append_op(
            type="set_value",
            inputs=inputs,
            outputs={'Out': output},
            attrs=attrs,
            inplace_map={"Input": "Out"},
        )

        if not paddle.in_dynamic_mode():
            # map var to the new output
            paddle.jit.api.ProgramTranslator.get_instance()._params_map.add(
                cur_block.program, x.desc.id(), output
            )
        return output
    else:
        # step3.2: Case where advanced indexing is present.
        #   1. get the __getitem__ result of the basic indexing;
        #   2. transpose the original tensor so that the axes with advanced indexing come first;
        #   3. assign values to the sliced result by the index_put OP;
        #   4. transpose back and assign the result to the original tensor by the set_value OP.

        sub_tensor = get_tensor_with_basic_indexing(
            x,
            axes,
            starts,
            ends,
            steps,
            decrease_axes,
            none_axes,
            use_strided_slice,
        )
        (
            transed_sub_tensor,
            adjusted_advanced_index,
            transback_dim,
            _,
            _,
        ) = deal_advanced_index(sub_tensor, advanced_index, True)
        if not isinstance(values, Variable):
            values = paddle.assign(values).astype(transed_sub_tensor.dtype)

        # NOTE(zoooo0820): now basic indexing of __getitem__ will return a new Tensor both in dynamic and static mode
        # After strided support is ready and basic indexing returns a view of the Tensor in dynamic mode, the code
        # below should be changed for dynamic mode.
        if paddle.in_dynamic_mode():
            transed_sub_tensor.index_put_(adjusted_advanced_index, values)
        else:
            transed_sub_tensor = transed_sub_tensor.index_put(
                adjusted_advanced_index, values
            )

        transback_sub_tensor = transed_sub_tensor.transpose(transback_dim)

        inputs["ValueTensor"] = transback_sub_tensor
        if paddle.in_dynamic_mode():
            x._bump_inplace_version()
            output = x
        else:
            helper = paddle.fluid.layer_helper.LayerHelper(
                'set_value', **locals()
            )
            output = helper.create_variable_for_type_inference(dtype=x.dtype)
        cur_block = default_main_program().current_block()
        cur_block.append_op(
            type="set_value",
            inputs=inputs,
            outputs={'Out': output},
            attrs=attrs,
            inplace_map={"Input": "Out"},
        )
        if not paddle.in_dynamic_mode():
            # map var to the new output
            paddle.jit.api.ProgramTranslator.get_instance()._params_map.add(
                cur_block.program, x.desc.id(), output
            )
        return output


def get_tensor_with_basic_indexing(
    x, axes, starts, ends, steps, decrease_axes, none_axes, use_strided_slice
):
    from .dygraph.base import in_declarative_mode

    if in_declarative_mode() and hasattr(x, "is_view_var"):
        x.is_view_var = True

    if len(axes) == 0:
        out = x
    else:
        op_type = "strided_slice" if use_strided_slice else "slice"
        inputs = {'Input': [x]}
        attrs = {
            'axes': axes,
            'starts': [],
            'ends': [],
            'decrease_axis': decrease_axes,
        }
        if use_strided_slice:
            attrs['strides'] = []
        infer_flags = [1] * len(axes)
        deal_attrs(
            attrs, starts, "starts", "StartsTensorList", inputs, infer_flags
        )
        deal_attrs(attrs, ends, "ends", "EndsTensorList", inputs, infer_flags)
        deal_attrs(
            attrs, steps, "strides", "StridesTensorList", inputs, infer_flags
        )
        attrs['infer_flags'] = infer_flags

        if paddle.in_dynamic_mode():
            if "StartsTensorList" in inputs.keys():
                st = inputs['StartsTensorList']
            else:
                st = attrs['starts']
            if "EndsTensorList" in inputs.keys():
                end = inputs['EndsTensorList']
            else:
                end = attrs['ends']
            if "StridesTensorList" in inputs.keys():
                stride = inputs['StridesTensorList']
            else:
                stride = attrs['strides']
            if use_strided_slice:
                out = paddle._C_ops.strided_slice(x, axes, st, end, stride)
                if len(decrease_axes) > 0:
                    out = paddle._C_ops.squeeze(out, decrease_axes)
            else:
                out = paddle._C_ops.slice(
                    x,
                    axes,
                    st,
                    end,
                    attrs['infer_flags'],
                    attrs['decrease_axis'],
                )
        else:
            from .framework import default_main_program

            target_block = default_main_program().current_block()

            slice_out_var = target_block.create_var(
                name=unique_name.generate_with_ignorable_key(
                    x.name + "_" + op_type
                ),
                dtype=x.dtype,
            )
            target_block.append_op(
                type=op_type,
                inputs=inputs,
                outputs={'Out': [slice_out_var]},
                attrs=attrs,
            )
            out = slice_out_var
    # NOTE(zoooo0820): When all axes are decreased, the output will be 1-D
    # with FLAGS_set_to_1d=True. In this case, one `None` should be popped out,
    # otherwise the output shape will be not correct.
    set_to_1d = paddle.get_flags('FLAGS_set_to_1d')['FLAGS_set_to_1d']
    if set_to_1d and len(decrease_axes) == len(x.shape):
        warnings.warn(
            "Warning: In Tensor '__getitem__', if the number of scalar elements in the index is equal to the rank of the Tensor, the output should be 0-D. In order to be consistent with the behavior of previous versions, it will be processed to 1-D. But it is not correct and will be removed in release 2.6. If 1-D is still wanted, please modify the index element from scalar to slice (e.g. 'x[i]' => 'x[i:i+1]')."
        )
        none_axes = none_axes[1:]

    if len(none_axes) > 0:
        # Deal with cases where decrease_axes is not empty
        # For example:
        # # x.shape: (2,3,4)
        # out = x[0, 0:2, None] # out.shape : (2, 1, 4)
        for idx, axis in enumerate(none_axes):
            l = len([i for i in decrease_axes if i < axis])
            new_axis = axis - l
            none_axes[idx] = new_axis

        out = paddle.unsqueeze(out, axis=none_axes)

    if in_declarative_mode() and hasattr(out, "is_view_var"):
        out.is_view_var = True
    return out


def _getitem_static(x, indices):
    """
    Args:
        x(Tensor): Tensor to be indexing.
        indices(int|slice|None|Tensor|List|Tuple...): Indices, used to indicate the position of the element to be fetched.
    """
    # step1: parse the index and record its components
    (
        starts,
        ends,
        steps,
        axes,
        none_axes,
        decrease_axes,
        advanced_index,
        has_advanced_index,
        use_strided_slice,
    ) = parse_index(x, indices)

    # step2: Dealing with basic indexing
    out = get_tensor_with_basic_indexing(
        x,
        axes,
        starts,
        ends,
        steps,
        decrease_axes,
        none_axes,
        use_strided_slice,
    )

    # step3: Dealing with advanced indexing
    if has_advanced_index:
        (
            transed_tensor,
            adjusted_advanced_index,
            _,
            pos_of_new_dim,
            rank_of_new_dim,
        ) = deal_advanced_index(out, advanced_index, False)

        # TODO(zooooo0820): Replace gather_nd with another advanced OP to handle mixed indexes more efficiently
        if (
            len(adjusted_advanced_index) == 1
            and adjusted_advanced_index[0].dtype == paddle.bool
        ):
            # Note: slice does not support 0-size Tensors for now, so only a single bool tensor can return an empty 0-size result.
            out = get_value_for_bool_tensor(
                transed_tensor, adjusted_advanced_index[0]
            )
        else:
            adjusted_advanced_index = parse_bool_and_broadcast_indices(
                adjusted_advanced_index
            )
            advanced_index_tensor = paddle.stack(
                adjusted_advanced_index, axis=-1
            )
            out = paddle.gather_nd(transed_tensor, advanced_index_tensor)

        if pos_of_new_dim != 0:
            perm = (
                list(range(pos_of_new_dim, pos_of_new_dim + rank_of_new_dim))
                + list(range(0, pos_of_new_dim))
                + list(range(pos_of_new_dim + rank_of_new_dim, out.ndim))
            )
            out = out.transpose(perm)

    return out
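
# Usage sketch (hypothetical calls, equivalent to Python indexing):
#   _getitem_static(x, (0, slice(None), None))  # like x[0, :, None], basic indexing
#   _getitem_static(x, ([0, 2],))               # advanced indexing: gather rows 0 and 2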


def parse_bool_and_broadcast_indices(indices):
    # Deal with multiple Tensors and translate bool tensors to int tensors.
    # In static mode, a bool tensor cannot be broadcast since its corresponding int tensor's shape cannot be inferred.
    for i, indice in enumerate(indices):
        if indice.dtype == paddle.bool:
            indices[i] = paddle.nonzero(indice)[:, 0]
    if len(indices) > 1:
        indices = paddle.broadcast_tensors(indices)
    return indices
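

# Sketch: a bool mask index [True, False, True] becomes the int index [0, 2]
# via paddle.nonzero; multiple integer index tensors are then broadcast to a
# common shape (the caller stacks them along the last axis for gather_nd).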