#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from functools import reduce

import numpy as np

from . import unique_name
from . import core
import paddle
import warnings


MAX_INTEGER = 2**31 - 1


def is_list_tuple(index, contain_type):
    def _is_list_tuple(item):
        if not (isinstance(item, (list, tuple)) or type(item) == contain_type):
            return False
        if isinstance(item, (tuple, list)):
            for s in item:
                if not _is_list_tuple(s):
                    return False
        return True

    if not isinstance(index, (tuple, list)):
        return False
    for s in index:
        if not _is_list_tuple(s):
            return False
    return True
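
# e.g. (sketch): is_list_tuple([[0, 1], [2, 3]], int) -> True
#                is_list_tuple([0, slice(None)], int) -> False (a slice is not int)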


def is_one_dim_list(index, contain_type):
    if isinstance(index, list):
        for i in index:
            if not isinstance(i, contain_type):
                return False
    else:
        return False
    return True
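
# e.g. (sketch): is_one_dim_list([1, 2, 3], int) -> True
#                is_one_dim_list([[1], 2], int) -> False (nested list)
#                is_one_dim_list((1, 2), int) -> False (tuple, not list)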


def get_list_index_shape(var_dims, index_dims):
    var_dims_size = len(var_dims)
    index_dims_size = len(index_dims)

    out_dims_size = var_dims_size - index_dims[0] + index_dims_size - 1

    out_dims_shape = [1] * out_dims_size

    out_dims_shape[: index_dims_size - 1] = index_dims[1:]

    out_dims_shape[index_dims_size - 1 :] = var_dims[index_dims[0] :]
    return out_dims_shape
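
# e.g. (sketch): stacking 2 index tensors of shape [3] over a 4-D var:
#   get_list_index_shape([3, 4, 5, 6], [2, 3]) -> [3, 5, 6]
#   (the first index_dims[0]=2 dims of var are consumed by the gather; the
#    index shape [3] is prepended to the remaining dims [5, 6])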


class SliceInfo:
    def __init__(self):
        self.pre_shape = None
        self.indexes = []
        self.dtype = None

    def update(self, index):
        if is_list_tuple(index, int) or isinstance(
            index, (paddle.base.Variable, np.ndarray)
        ):
            # convert index to Tensor
            if not isinstance(index, paddle.base.Variable):
                index = paddle.assign(index)

            if self.dtype is None:
                self.dtype = index.dtype
            else:
                if index.dtype != self.dtype:
                    raise IndexError(
                        "Data type of Tensor/List index should be the same. The current data type is {}, but the previous data type is {}.".format(
                            index.dtype, self.dtype
                        )
                    )

            self.indexes.append(index)

            if self.pre_shape is None:
                self.pre_shape = index.shape
            else:
                if self.pre_shape != index.shape:
                    # broadcast all stored indexes to a common shape
                    cur_shape = paddle.broadcast_shape(
                        self.pre_shape, index.shape
                    )
                    for i in range(len(self.indexes)):
                        self.indexes[i] = paddle.broadcast_to(
                            self.indexes[i], cur_shape
                        )
                self.pre_shape = self.indexes[-1].shape
        else:
            raise ValueError(
                "Index should be list/tuple of int or Tensor, but received {}.".format(
                    index
                )
            )

    def shape_stride(self, shape):
        s = [1] * len(shape)
        for i in range(len(shape) - 2, -1, -1):
            s[i] = shape[i + 1] * s[i + 1]

        return s
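
    # e.g. (sketch): self.shape_stride([3, 4, 5]) -> [20, 5, 1]
    # (row-major element strides for a tensor of shape [3, 4, 5])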

    def numel(self, shape):
        return reduce(lambda x, y: x * y, shape, 1)

    def get_offset_stride(self, tensor_shape):
        for index in self.indexes:
            if not isinstance(index, paddle.base.Variable):
                raise ValueError(
                    "only support list/tensor index, but received {}.".format(
                        type(index)
                    )
                )

        if len(self.indexes) <= len(tensor_shape) or len(self.indexes) == 1:
            shape = paddle.stack(self.indexes)
            axes = list(range(1, len(self.pre_shape) + 1)) + [
                0,
            ]

        else:
            raise ValueError(
                "too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format(
                    len(tensor_shape), self.pre_shape[0]
                )
            )

        shape_transpose = paddle.transpose(shape, axes)
        return shape_transpose

    def get_item(self, tensor):
        shape_transpose = self.get_offset_stride(tensor.shape)
        index = paddle.assign(shape_transpose)
        return paddle.gather_nd(tensor, index)
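
    # e.g. (sketch): si = SliceInfo(); si.update([0, 2]); si.get_item(t)
    # gathers rows 0 and 2 of t via gather_nd on the stacked index tensors.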

    def set_item(self, tensor_origin, value):
        if not isinstance(value, paddle.base.Variable):
            value = paddle.assign(value)
        tensor_type = None

        if tensor_origin.dtype in [
            core.VarDesc.VarType.FP32,
            core.VarDesc.VarType.FP64,
        ]:
            tensor = tensor_origin
        else:
            tensor_type = tensor_origin.dtype
            tensor = tensor_origin.astype(core.VarDesc.VarType.FP32)

        if value.dtype != tensor.dtype:
            value = value.astype(tensor.dtype)

        shape_transpose = self.get_offset_stride(tensor_origin.shape)
        index = paddle.assign(shape_transpose)

        gather_tensor_shape = get_list_index_shape(
            tensor.shape,
            [
                len(self.indexes),
            ]
            + list(self.indexes[-1].shape),
        )

        value_dims_bd = [
            1,
        ] * len(gather_tensor_shape)
        value_dims_bd[-len(value.shape) :] = list(value.shape)

        for i in range(len(gather_tensor_shape)):
            if not (
                len(value_dims_bd) == 0
                or value_dims_bd[i] == gather_tensor_shape[i]
                or value_dims_bd[i] == 1
            ):
                raise ValueError(
                    "{} cannot broadcast into {}".format(
                        value.shape, gather_tensor_shape
                    )
                )

        value_broadcast = paddle.broadcast_to(value, gather_tensor_shape)

        value_1d = value_broadcast.reshape(
            [-1] + gather_tensor_shape[len(index.shape) - 1 :]
        )

        index_1d = index.reshape([-1, index.shape[-1]])

        tensor_stride = paddle.assign(
            self.shape_stride(tensor.shape[: index.shape[-1]])
        )
        inds = []
        for i in range(index_1d.shape[0]):
            temp = (index_1d[i] * tensor_stride).sum()
            inds.append(temp)
        index_1d = paddle.stack(inds).reshape([-1])
        t_reshape = tensor.reshape([-1] + list(tensor.shape[index.shape[-1] :]))
        out = paddle.scatter(t_reshape, index_1d, value_1d)
        if tensor_type is not None:
            out = out.astype(tensor_type)
        tensor_origin = _setitem_impl_(
            tensor_origin, ..., out.reshape(tensor_origin.shape)
        )

        return tensor_origin


def replace_ellipsis(var, item):
    from .framework import Variable

    # Use slice(None) to replace Ellipsis.
    # For var, var.shape = [3,4,5,6]
    #
    #   var[..., 1:2] -> var[:, :, :, 1:2]
    #   var[0, ...] -> var[0]
    #   var[0, ..., 1:2] -> var[0, :, :, 1:2]

    item = list(item)

    # Remove Variable/ndarray and None entries to avoid errors when counting Ellipsis
    item_remove_var = [
        ele
        for ele in item
        if not isinstance(ele, (Variable, np.ndarray)) and ele is not None
    ]
    ell_count = item_remove_var.count(Ellipsis)
    if ell_count == 0:
        return item
    elif ell_count > 1:
        raise IndexError("An index can only have a single ellipsis ('...')")

    ell_idx = item.index(Ellipsis)

    if ell_idx == len(item) - 1:
        return item[:-1]
    else:
        item[ell_idx : ell_idx + 1] = [slice(None)] * (
            len(var.shape) - len(item) + item.count(None) + 1
        )

    return item


def replace_ndarray_and_range(item):
    new_item = []
    for slice_item in item:
        if isinstance(slice_item, np.ndarray):
            new_item.append(paddle.assign(slice_item))
        elif isinstance(slice_item, range):
            new_item.append(list(slice_item))
        else:
            new_item.append(slice_item)
    return new_item
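
# e.g. (sketch): replace_ndarray_and_range([range(2), 0]) -> [[0, 1], 0];
# an np.ndarray item would be converted to a Tensor via paddle.assign.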


def replace_none(item):
    new_item = []
    none_axes = []
    for i, slice_item in enumerate(item):
        if slice_item is None:
            none_axes.append(i)
        else:
            new_item.append(slice_item)
    return new_item, none_axes
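
# e.g. (sketch): replace_none([slice(None), None, 0])
#   -> ([slice(None), 0], [1])  # None removed, its axis recorded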


def is_integer_or_scalar_tensor(ele):
    from .framework import Variable

    if isinstance(ele, int):
        return True
    elif isinstance(ele, Variable):
        # NOTE(zoooo0820): For compatibility, if FLAGS_set_to_1d is set to True,
        # a 1-D tensor is still treated as a scalar, which means basic indexing.
        # This will be removed in a future version.
        if paddle.get_flags('FLAGS_set_to_1d')['FLAGS_set_to_1d']:
            if len(ele.shape) == 1 and ele.shape[0] == 1:
                warnings.warn(
                    "1-D Tensor will be treated as advanced indexing in a future version. Currently, a 1-D Tensor means a scalar, not a vector; please change it to a 0-D Tensor. If advanced indexing is needed, please use `export FLAGS_set_to_1d=False` to set the flag."
                )
                return True
        if len(ele.shape) == 0 and ele.dtype != paddle.bool:
            return True
    return False


def is_bool_tensor(ele):
    from .framework import Variable

    if isinstance(ele, Variable) and ele.dtype == paddle.bool:
        return True
    return False


def deal_attrs(attrs, attr, attr_name, tensor_attr_name, inputs, infer_flags):
    from .framework import Variable

    if paddle.utils._contain_var(attr):
        inputs[tensor_attr_name] = paddle.utils._convert_to_tensor_list(
            attr, dtype="int64"
        )
        for i, dim in enumerate(attr):
            if isinstance(dim, Variable):
                attrs[attr_name].append(-1)
                infer_flags[i] = -1
            else:
                attrs[attr_name].append(dim)
    else:
        attrs[attr_name] = attr
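
# e.g. (sketch): deal_attrs(attrs, [0, 2], 'starts', 'StartsTensorList', inputs,
# infer_flags) simply sets attrs['starts'] = [0, 2]; if any element were a
# Variable, the whole list would also be converted to tensors in
# inputs['StartsTensorList'], with -1 recorded in attrs['starts'] and
# infer_flags at that position.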


# the item is a tensor of bool
def get_value_for_bool_tensor(var, item):
    if len(item.shape) > len(var.shape):
        raise IndexError(
            "The dims of bool index don't match the indexed array; "
            "the dims of the bool index are expected to be equal to or less "
            "than {}, but received {}.".format(len(var.shape), len(item.shape))
        )
    i = 0
    item_shape = item.shape
    while i < len(item.shape):
        dim_len = item_shape[i]
        if dim_len != -1 and var.shape[i] != -1 and dim_len != var.shape[i]:
            raise IndexError(
                "The dimension of bool index doesn't match the indexed array along "
                "dimension {}; the target dimension is {}, but received {}.".format(
                    i, var.shape[i], dim_len
                )
            )
        i += 1
    empty_shape = [0] + list(var.shape[i:])

    def idx_not_empty(var, item):
        from ..tensor import gather_nd

        bool_2_idx = paddle.nonzero(item)
        return gather_nd(var, bool_2_idx)

    from paddle.static.nn import cond

    # If the bool index is all False, return an empty (0-size) tensor instead.
    return cond(
        item.any(),
        lambda: idx_not_empty(var, item),
        lambda: paddle.empty(empty_shape, var.dtype),
    )
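
# e.g. (sketch): for var of shape [3, 2] and a bool mask [True, False, True],
# get_value_for_bool_tensor gathers rows 0 and 2 -> shape [2, 2]; an all-False
# mask yields an empty tensor of shape [0, 2].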


def _setitem_for_tensor_array(var, item, value):
    """Branches for the tensor array setitem operation.
    An item can be:
    (1) int/Variable, which is a simple number/variable such as [1], [-2]
    (2) Slice, which is represented by bounds such as [2:-1]
    (3) Tuple, which includes the above two cases such as [2:-1, 1]
    If item is case (1), we perform paddle.tensor.array_write,
    in other cases, we raise a NotImplementedError.
    """

    from .framework import Variable

    assert (
        not paddle.in_dynamic_mode()
    ), "setitem for tensor_array must be called in static graph mode."
    if isinstance(item, (Variable, int)):
        from paddle.jit.dy2static.variable_trans_func import (
            to_static_variable,
        )
        from paddle.tensor import array_write

        item = paddle.cast(to_static_variable(item), dtype='int64')
        value = to_static_variable(value)
        return array_write(x=value, i=item, array=var)
    else:
        raise NotImplementedError(
            "Only support __setitem__ by Int/Variable in tensor_array, but received {}".format(
                type(item)
            )
        )


def _setitem_impl_(var, item, value):
    from paddle.base import core
    from .framework import default_main_program, Variable

    if var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
        return _setitem_for_tensor_array(var, item, value)

    inputs = {'Input': var}
    if isinstance(item, list):
        if not is_one_dim_list(item, int):
            item = tuple(item)
    # 1. Parse item
    if not isinstance(item, tuple):
        item = (item,)

    decrease_axes = []
    axes = []
    starts = []
    ends = []
    steps = []

    item = replace_ndarray_and_range(item)
    item = replace_ellipsis(var, item)
    item, none_axes = replace_none(item)
    slice_info = SliceInfo()
    dim = 0
    for _, slice_item in enumerate(item):
        if is_integer_or_scalar_tensor(slice_item) and not is_bool_tensor(
            slice_item
        ):
            decrease_axes.append(dim)
            start = slice_item
            end = slice_item + 1 if slice_item != -1 else MAX_INTEGER
            step = 1

        elif isinstance(slice_item, slice):
            start = slice_item.start
            end = slice_item.stop
            step = slice_item.step

            if start is None and end is None and step is None:
                dim += 1
                continue

            step = 1 if step is None else step

            if not isinstance(step, Variable) and step == 0:
                raise ValueError(
                    "When assigning a value to a paddle.Tensor, step cannot be 0, "
                    "but received step is {}.".format(step)
                )

            if isinstance(step, Variable) and (start is None or end is None):
                raise ValueError(
                    "When assigning a value to a paddle.Tensor, it's not supported that "
                    "the start or end is None when the type of step is paddle.Tensor."
                )

            if start is None:
                start = 0 if step > 0 else MAX_INTEGER

            if end is None:
                end = MAX_INTEGER if step > 0 else (0 - MAX_INTEGER)
        elif isinstance(slice_item, list):
            if is_list_tuple(slice_item, int):
                slice_info.update(slice_item)
                continue

            for i in slice_item:
                if not isinstance(i, bool):
                    raise TypeError(
                        "Doesn't support {} in index list.".format(type(i))
                    )

            if len(item) != 1:
                raise IndexError(
                    "When index contains a bool list, its length must be 1, but received {}.".format(
                        len(item)
                    )
                )

            idx_tensor = paddle.assign(slice_item)
            return set_value_for_bool_tensor(var, idx_tensor, value)

        elif isinstance(slice_item, Variable):
            if slice_item.dtype == core.VarDesc.VarType.BOOL:
                if len(item) != 1:
                    raise IndexError(
                        "When index contains a bool tensor, its length must be 1, but received {}.".format(
                            len(item)
                        )
                    )
                return set_value_for_bool_tensor(var, slice_item, value)
            else:
                slice_info.update(slice_item)
                continue
        else:
            raise IndexError(
                "Valid index types are int, slice, ellipsis, None, list of bool, and Variable, "
                "but received {}.".format(slice_item)
            )

        axes.append(dim)
        starts.append(start)
        ends.append(end)
        steps.append(step)

        dim += 1
    if slice_info.indexes:
        if len(slice_info.indexes) != len(item):
            raise IndexError(
                "Valid index types are int, slice, ellipsis, or list, but received {}.".format(
                    item
                )
            )
        return slice_info.set_item(var, value)
    attrs = {
        'axes': axes,
        'starts': starts,
        'ends': ends,
        'steps': steps,
        'decrease_axes': decrease_axes,
        'none_axes': none_axes,
    }

    if paddle.utils._contain_var(starts):
        inputs['StartsTensorList'] = paddle.utils._convert_to_tensor_list(
            starts
        )
        del attrs['starts']
    if paddle.utils._contain_var(ends):
        inputs['EndsTensorList'] = paddle.utils._convert_to_tensor_list(ends)
        del attrs['ends']
    if paddle.utils._contain_var(steps):
        inputs['StepsTensorList'] = paddle.utils._convert_to_tensor_list(steps)
        del attrs['steps']

    # 2. Parse value
    dtype = var.dtype
    attrs['dtype'] = dtype

    from .data_feeder import convert_dtype

    #  2.1 value is an integer, float or complex
    if isinstance(value, (bool, int, float, complex)):
        value = np.array([value]).astype(convert_dtype(dtype))

    #  2.2 value is a np.ndarray
    if isinstance(value, np.ndarray):
        shape = list(value.shape)
        values = value.ravel().tolist()
        attrs["values"] = values
        attrs["shape"] = shape

    elif isinstance(value, (Variable, core.eager.Tensor)):
        inputs["ValueTensor"] = value
    else:
        raise TypeError(
            "Only support to assign an integer, float, numpy.ndarray or "
            "paddle.Tensor to a paddle.Tensor, but received {}".format(
                type(value)
            )
        )

    if paddle.in_dynamic_mode():
        var._bump_inplace_version()
        output = var
    else:
        helper = paddle.base.layer_helper.LayerHelper('set_value', **locals())
        if helper.main_program.current_block_idx != 0:
            # not in global block, we should create a global variable.
            output = helper._create_global_variable_for_type_inference(
                dtype=var.dtype
            )
        else:
            output = helper.create_variable_for_type_inference(dtype=var.dtype)

    cur_block = default_main_program().current_block()
    cur_block.append_op(
        type="set_value",
        inputs=inputs,
        outputs={'Out': output},
        attrs=attrs,
        inplace_map={"Input": "Out"},
    )

    if not paddle.in_dynamic_mode():
        # map var to the new output
        from paddle.jit.dy2static.program_translator import (
            ProgramTranslator,
        )

        ProgramTranslator.get_instance()._inplace_map.add(
            cur_block.program, var.desc.id(), output
        )

    return output


# the item is a tensor of bool
def set_value_for_bool_tensor(var, item, value):
    if len(item.shape) > len(var.shape):
        raise IndexError(
            "The dims of bool index don't match the indexed array; "
            "the dims of the bool index are expected to be equal to or less "
            "than {}, but received {}.".format(len(var.shape), len(item.shape))
        )
    for i, dim_len in enumerate(item.shape):
        if dim_len != -1 and var.shape[i] != -1 and dim_len != var.shape[i]:
            raise IndexError(
                "The dimension of bool index doesn't match the indexed array along "
                "dimension {}; the target dimension is {}, but received {}.".format(
                    i, var.shape[i], dim_len
                )
            )

    def idx_not_empty(var, item, value):
        from .framework import Variable
        from ..tensor import gather_nd, scatter_nd_add

        if not isinstance(value, Variable):
            value = paddle.assign(value).cast(var.dtype)

        idx = paddle.nonzero(item)
        gather_val = gather_nd(var, idx)
        gather_val_new = value - gather_val
        out = scatter_nd_add(var, idx, gather_val_new)
        var = _setitem_impl_(var, ..., out)
        return var

    def idx_is_empty(var):
        return var

    from paddle.static.nn import cond

    # If the bool index is all False, just do nothing
    var = cond(
        item.any(),
        lambda: idx_not_empty(var, item, value),
        lambda: idx_is_empty(var),
    )

    return var


def deal_advanced_index(ori_tensor, indices, is_for_setitem):
    """
    Transpose origin Tensor and advanced indices to the front.

    Returns:
        transed_tensor (Tensor): transposed tensor, corresponding to the advanced indices
        transed_index (List): advanced indices transed to the front
        trans_back_dim (List): order of axes to transpose back to original order. Only used in __setitem__.
        pos_of_new_dim (int):  axis of new dim in the result. Only used in __getitem__.
        rank_of_new_dim (int): rank of new dim in the result. Only used in __getitem__.
    """
    transed_dim = []
    transed_index = []

    # These flags indicate whether the result obtained by gather_nd requires a second transpose.
    # Only used in __getitem__.
    pos_of_new_dim = MAX_INTEGER
    rank_of_new_dim = 1

    for i, indice in enumerate(indices):
        if indice is not None:
            if not is_for_setitem:
                if i == 0:
                    # case 1: advanced indices at axis 0, the new dim will be at first.
                    pos_of_new_dim = 0
                if i > 0 and len(transed_dim) > 0 and transed_dim[-1] != i - 1:
                    # case 2: there are not adjacent advanced indices, the new dim will be at first.
                    pos_of_new_dim = 0
                else:
                    pos_of_new_dim = min(pos_of_new_dim, i)
                rank_of_new_dim = max(rank_of_new_dim, indice[1].ndim)
            transed_dim.append(i)
            transed_index.append(indice[1])
    for i in range(ori_tensor.ndim):
        if indices[i] is None:
            transed_dim.append(i)
    transed_tensor = ori_tensor.transpose(transed_dim)

    trans_back_dim = np.argsort(transed_dim).tolist() if is_for_setitem else []

    return (
        transed_tensor,
        transed_index,
        trans_back_dim,
        pos_of_new_dim,
        rank_of_new_dim,
    )
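
# e.g. (sketch): for ori_tensor of shape [2, 3, 4] with one advanced index at
# axis 1 (indices[1] = (1, idx), others None), axis 1 is moved to the front:
# transed_tensor has shape [3, 2, 4], transed_index == [idx], and
# trans_back_dim == [1, 0, 2] restores the original order when is_for_setitem.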


def parse_index(x, indices):
    advanced_index = [None] * 2 * len(x.shape)  # content is (dim, index)
    # for set_value / slice / strided_slice OP
    decrease_axes = []
    axes = []
    starts = []
    ends = []
    steps = []
    use_strided_slice = False
    has_advanced_index = False

    if isinstance(indices, list) and not is_one_dim_list(indices, int):
        indices = tuple(indices)

    if not isinstance(indices, tuple):
        indices = (indices,)

    indices = replace_ndarray_and_range(indices)
    indices = replace_ellipsis(x, indices)
    indices, none_axes = replace_none(indices)

    is_tensor_array = (
        hasattr(x, "desc")
        and x.desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY
    )

    estimated_dim = 0
    for dim, slice_item in enumerate(indices):
        start, end, step = None, None, None
        if is_integer_or_scalar_tensor(slice_item):
            if (
                not is_tensor_array
                and isinstance(slice_item, int)
                and x.shape[dim] is not None
                and x.shape[dim] >= 0
                and slice_item >= x.shape[dim]
            ):
                # For python, if users write a, b = var, the __getitem__
                # method will iterate through 0, 1, 2 ... until __getitem__
                # throws an IndexError, then stop. The var[0], var[1] will
                # be given to a, b respectively. If more values are given,
                # the unpack size would cause an error.
                # We raise an IndexError here to support grammar like `a, b = var`.
                raise IndexError(
                    "slice_item %d at dim %d should be >= 0 and < x.shape[%d]: %d"
                    % (slice_item, dim, dim, x.shape[dim])
                )
            # not calculate result to reduce call times for slice OP.
            decrease_axes.append(dim)
            start = slice_item
            step = 1
            end = slice_item + 1 if slice_item != -1 else MAX_INTEGER
        elif isinstance(slice_item, bool):
            # single bool is advanced-indexing
            none_axes.append(dim)
            estimated_dim += 1
            advanced_index[estimated_dim] = (
                estimated_dim,
                paddle.to_tensor(slice_item),
            )
            has_advanced_index = True
        elif isinstance(slice_item, slice):
            start = slice_item.start
            end = slice_item.stop
            step = slice_item.step
            estimated_dim += 1

            if start is None and end is None and step is None:
                continue

            step = 1 if step is None else step
            if start is None:
                start = 0 if step > 0 else MAX_INTEGER
            if end is None:
                end = MAX_INTEGER if step > 0 else -1

        elif isinstance(slice_item, (list, tuple)):
            advanced_index[estimated_dim] = (
                estimated_dim,
                paddle.to_tensor(slice_item),
            )

            if (
                advanced_index[estimated_dim][1].dtype == paddle.bool
                and len(slice_item) != x.shape[dim]
            ):
                raise IndexError(
                    "The shape of boolean index {} did not match indexed tensor {} along axis {}".format(
                        len(slice_item), x.shape[dim], dim
                    )
                )

            has_advanced_index = True
            estimated_dim += 1

        elif isinstance(slice_item, paddle.base.Variable):
            # In this case, the Variable is not 0-dim Tensor and will be treated as advanced-indexing.
            if slice_item.dtype == paddle.bool:
                if slice_item.ndim == 0:
                    # 0-D bool Tensor, same as single PY-bool.
                    none_axes.append(dim)

                elif slice_item.shape[0] != x.shape[dim]:
                    raise IndexError(
                        "The shape of boolean index {} did not match indexed tensor {} along axis {}".format(
                            slice_item.shape[0], x.shape[dim], dim
                        )
                    )
            advanced_index[estimated_dim] = (estimated_dim, slice_item)
            has_advanced_index = True
            estimated_dim += 1

        else:
            raise IndexError(
                "Valid index types are int / bool / slice / ellipsis / list / Tuple / Ndarray / Tensor, but received {}.".format(
                    slice_item
                )
            )
        if not (start is None or end is None or step is None):
            starts.append(start)
            ends.append(end)
            steps.append(step)
            axes.append(dim)
            use_strided_slice = (
                True
                if (isinstance(step, paddle.base.Variable) or step != 1)
                else use_strided_slice
            )
    return (
        starts,
        ends,
        steps,
        axes,
        none_axes,
        decrease_axes,
        advanced_index,
        has_advanced_index,
        use_strided_slice,
    )
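
# e.g. (sketch): for x of shape [4, 5, 6] and indices = (0, slice(1, 3)),
# parse_index returns starts=[0, 1], ends=[1, 3], steps=[1, 1], axes=[0, 1],
# none_axes=[], decrease_axes=[0], no advanced index, and
# use_strided_slice=False since every step is 1.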


def _setitem_static(x, indices, values):
    """
    In dynamic mode, this function modifies the input tensor in place and returns the same Tensor.
    In static mode, it returns a new Tensor with the assigned value.

    Args:
        x(Tensor): Tensor to be set value.
        indices(int|slice|None|Tensor|List|Tuple...): Indices, used to indicate the position of the element to be fetched.
        values(Tensor|Number|Ndarray): values to be assigned to the x.
    """
    from .framework import default_main_program, Variable

    if x.type == paddle.base.core.VarDesc.VarType.LOD_TENSOR_ARRAY:
        return _setitem_for_tensor_array(x, indices, values)

    # step1: parsing the index and recording them
    (
        starts,
        ends,
        steps,
        axes,
        none_axes,
        decrease_axes,
        advanced_index,
        has_advanced_index,
        use_strided_slice,
    ) = parse_index(x, indices)

    inputs = {'Input': x}
    attrs = {
        'axes': axes,
        'starts': starts,
        'ends': ends,
        'steps': steps,
        'decrease_axes': decrease_axes,
        'none_axes': none_axes,
    }
    if paddle.utils._contain_var(starts):
        inputs['StartsTensorList'] = paddle.utils._convert_to_tensor_list(
            starts
        )
        del attrs['starts']
    if paddle.utils._contain_var(ends):
        inputs['EndsTensorList'] = paddle.utils._convert_to_tensor_list(ends)
        del attrs['ends']
    if paddle.utils._contain_var(steps):
        inputs['StepsTensorList'] = paddle.utils._convert_to_tensor_list(steps)
        del attrs['steps']

    if not has_advanced_index:
        # step2. Parse values
        dtype = x.dtype
        attrs['dtype'] = dtype

        from .data_feeder import convert_dtype

        if isinstance(values, (bool, int, float, complex)):
            values = np.array([values]).astype(convert_dtype(dtype))

        if isinstance(values, np.ndarray):
            shape = list(values.shape)
            values = values.ravel().tolist()
            attrs["values"] = values
            attrs["shape"] = shape

        elif isinstance(values, Variable):
            inputs["ValueTensor"] = values
        else:
            raise TypeError(
                "Only support to assign an integer, float, numpy.ndarray or "
                "paddle.Tensor to a paddle.Tensor, but received {}".format(
                    type(values)
                )
            )

        # step3.1: Only basic indexing, use OP set_value to set value.
        if paddle.in_dynamic_mode():
            x._bump_inplace_version()
            output = x
        else:
            helper = paddle.base.layer_helper.LayerHelper(
                'set_value', **locals()
            )
            if helper.main_program.current_block_idx != 0:
                # not in global block, we should create a global variable.
                output = helper._create_global_variable_for_type_inference(
                    dtype=x.dtype
                )
            else:
                output = helper.create_variable_for_type_inference(
                    dtype=x.dtype
                )
        cur_block = default_main_program().current_block()
        cur_block.append_op(
            type="set_value",
            inputs=inputs,
            outputs={'Out': output},
            attrs=attrs,
            inplace_map={"Input": "Out"},
        )

        if not paddle.in_dynamic_mode():
            # map var to the new output
            paddle.jit.api.ProgramTranslator.get_instance()._inplace_map.add(
                cur_block.program, x.desc.id(), output
            )
        return output
    else:
        # step3.2: Case for there are advanced indexing.
        #   1. get __getitem__ result of basic indexing;
        #   2. transpose original tensor so that the axis with advanced indexing will come to the first;
        #   3. assign values to the sliced result by index_put OP;
        #   4. transpose back and assign the result to original tensor by set_value OP.

        sub_tensor = get_tensor_with_basic_indexing(
            x,
            axes,
            starts,
            ends,
            steps,
            decrease_axes,
            none_axes,
            use_strided_slice,
        )
        (
            transed_sub_tensor,
            adjusted_advanced_index,
            transback_dim,
            _,
            _,
        ) = deal_advanced_index(sub_tensor, advanced_index, True)
        if not isinstance(values, Variable):
            values = paddle.assign(values).astype(transed_sub_tensor.dtype)

        # NOTE(zoooo0820): currently, basic indexing in __getitem__ returns a new
        # Tensor in both dynamic and static mode. Once strided support is ready
        # and basic indexing returns a view of the Tensor in dynamic mode, the
        # code below should be changed for dynamic mode.
        if paddle.in_dynamic_mode():
            transed_sub_tensor.index_put_(adjusted_advanced_index, values)
        else:
            transed_sub_tensor = transed_sub_tensor.index_put(
                adjusted_advanced_index, values
            )

        transback_sub_tensor = transed_sub_tensor.transpose(transback_dim)

        inputs["ValueTensor"] = transback_sub_tensor
        if paddle.in_dynamic_mode():
            x._bump_inplace_version()
            output = x
        else:
            helper = paddle.base.layer_helper.LayerHelper(
                'set_value', **locals()
            )
            if helper.main_program.current_block_idx != 0:
                # not in global block, we should create a global variable.
                output = helper._create_global_variable_for_type_inference(
                    dtype=x.dtype
                )
            else:
                output = helper.create_variable_for_type_inference(
                    dtype=x.dtype
                )
        cur_block = default_main_program().current_block()
        cur_block.append_op(
            type="set_value",
            inputs=inputs,
            outputs={'Out': output},
            attrs=attrs,
            inplace_map={"Input": "Out"},
        )
        if not paddle.in_dynamic_mode():
            # map var to the new output
            paddle.jit.api.ProgramTranslator.get_instance()._inplace_map.add(
                cur_block.program, x.desc.id(), output
            )
        return output


def get_tensor_with_basic_indexing(
    x, axes, starts, ends, steps, decrease_axes, none_axes, use_strided_slice
):
    from .dygraph.base import in_to_static_mode

    if in_to_static_mode() and hasattr(x, "is_view_var"):
        x.is_view_var = True

    if len(axes) == 0:
        out = x
    else:
        op_type = "strided_slice" if use_strided_slice else "slice"
        inputs = {'Input': [x]}
        attrs = {
            'axes': axes,
            'starts': [],
            'ends': [],
            'decrease_axis': decrease_axes,
        }
        if use_strided_slice:
            attrs['strides'] = []
        infer_flags = [1] * len(axes)
        deal_attrs(
            attrs, starts, "starts", "StartsTensorList", inputs, infer_flags
        )
        deal_attrs(attrs, ends, "ends", "EndsTensorList", inputs, infer_flags)
        deal_attrs(
            attrs, steps, "strides", "StridesTensorList", inputs, infer_flags
        )
        attrs['infer_flags'] = infer_flags

        if paddle.in_dynamic_mode():
            if "StartsTensorList" in inputs.keys():
                st = inputs['StartsTensorList']
            else:
                st = attrs['starts']
            if "EndsTensorList" in inputs.keys():
                end = inputs['EndsTensorList']
            else:
                end = attrs['ends']
            if "StridesTensorList" in inputs.keys():
                stride = inputs['StridesTensorList']
            else:
                stride = attrs['strides']
            if use_strided_slice:
                out = paddle._C_ops.strided_slice(x, axes, st, end, stride)
                if len(decrease_axes) > 0:
                    out = paddle._C_ops.squeeze(out, decrease_axes)
            else:
                out = paddle._C_ops.slice(
                    x,
                    axes,
                    st,
                    end,
                    attrs['infer_flags'],
                    attrs['decrease_axis'],
                )
        else:
            from .framework import default_main_program

            target_block = default_main_program().current_block()

            slice_out_var = target_block.create_var(
                name=unique_name.generate_with_ignorable_key(
                    x.name + "_" + op_type
                ),
                dtype=x.dtype,
            )
            target_block.append_op(
                type=op_type,
                inputs=inputs,
                outputs={'Out': [slice_out_var]},
                attrs=attrs,
            )
            out = slice_out_var
    # NOTE(zoooo0820): When all axes are decreased, the output will be 1-D
    # with FLAGS_set_to_1d=True. In this case, one `None` should be popped out,
    # otherwise the output shape will not be correct.
    set_to_1d = paddle.get_flags('FLAGS_set_to_1d')['FLAGS_set_to_1d']
    if set_to_1d and len(decrease_axes) == len(x.shape):
        warnings.warn(
            "Warning: In Tensor '__getitem__', if the number of scalar elements in the index is equal to the rank of the Tensor, the output should be 0-D. In order to be consistent with the behavior of previous versions, it will be processed to 1-D. But it is not correct and will be removed in release 2.6. If 1-D is still wanted, please modify the index element from scalar to slice (e.g. 'x[i]' => 'x[i:i+1]')."
        )
        none_axes = none_axes[1:]

    if len(none_axes) > 0:
        # Deal with cases that decrease_axes is not empty
        # For example:
        # # x.shape: (2,3,4)
        # out = x[0, 0:2, None] # out.shape : (2, 1, 4)
        for idx, axis in enumerate(none_axes):
            l = len([i for i in decrease_axes if i < axis])
            new_axis = axis - l
            none_axes[idx] = new_axis

        out = paddle.unsqueeze(out, axis=none_axes)

    if in_to_static_mode() and hasattr(out, "is_view_var"):
        out.is_view_var = True
    return out


def _getitem_static(x, indices):
    """
    Args:
        x(Tensor): Tensor to be indexing.
        indices(int|slice|None|Tensor|List|Tuple...): Indices, used to indicate the position of the element to be fetched.
    """
    # step1: parsing the index and recording them
    (
        starts,
        ends,
        steps,
        axes,
        none_axes,
        decrease_axes,
        advanced_index,
        has_advanced_index,
        use_strided_slice,
    ) = parse_index(x, indices)

    # step2: Dealing with basic indexing
    out = get_tensor_with_basic_indexing(
        x,
        axes,
        starts,
        ends,
        steps,
        decrease_axes,
        none_axes,
        use_strided_slice,
    )

    # step3: Dealing with advanced indexing
    if has_advanced_index:
        (
            transed_tensor,
            adjusted_advanced_index,
            _,
            pos_of_new_dim,
            rank_of_new_dim,
        ) = deal_advanced_index(out, advanced_index, False)

        # TODO(zooooo0820): Replace gather_nd with another advanced OP to handle mixed indexes more efficiently
        if (
            len(adjusted_advanced_index) == 1
            and adjusted_advanced_index[0].dtype == paddle.bool
        ):
            # Note: slice does not support 0-size Tensors for now, so only a single bool tensor can return an empty 0-size result.
            out = get_value_for_bool_tensor(
                transed_tensor, adjusted_advanced_index[0]
            )
        else:
            adjusted_advanced_index = parse_bool_and_broadcast_indices(
                adjusted_advanced_index
            )
            advanced_index_tensor = paddle.stack(
                adjusted_advanced_index, axis=-1
            )
            out = paddle.gather_nd(transed_tensor, advanced_index_tensor)

        if pos_of_new_dim != 0:
            perm = (
                list(range(pos_of_new_dim, pos_of_new_dim + rank_of_new_dim))
                + list(range(0, pos_of_new_dim))
                + list(range(pos_of_new_dim + rank_of_new_dim, out.ndim))
            )
            out = out.transpose(perm)

    return out


def parse_bool_and_broadcast_indices(indices):
    # Deal with multiple Tensors and translate bool tensors to int tensors.
    # In static mode, a bool tensor cannot be broadcast since its corresponding
    # int tensor's shape cannot be inferred.
    for i, indice in enumerate(indices):
        if indice.dtype == paddle.bool:
            indices[i] = paddle.nonzero(indice)[:, 0]
    if len(indices) > 1:
        indices = paddle.broadcast_tensors(indices)
    return indices
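
# e.g. (sketch): a bool index Tensor([True, False, True]) is translated to the
# int index Tensor([0, 2]); when several index tensors are present they are
# broadcast to a common shape by paddle.broadcast_tensors.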