#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from functools import reduce  # used by SliceInfo.numel below
import numpy as np
from . import unique_name
from . import core
import paddle
import warnings


MAX_INTEGER = 2**31 - 1


def is_list_tuple(index, contain_type):
    def _is_list_tuple(item):
        if not (isinstance(item, (list, tuple)) or type(item) == contain_type):
            return False
        if isinstance(item, (tuple, list)):
            for s in item:
                if not _is_list_tuple(s):
                    return False
        return True

    if not isinstance(index, (tuple, list)):
        return False
    for s in index:
        if not _is_list_tuple(s):
            return False
    return True
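
# A quick illustration (hypothetical inputs): is_list_tuple checks that every
# leaf of a nested list/tuple has the given type, e.g.
#   is_list_tuple([[1, 2], [3, 4]], int)  -> True
#   is_list_tuple([1, slice(2)], int)     -> False  (slice is not int)
#   is_list_tuple(5, int)                 -> False  (top level must be list/tuple)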


def get_list_index_shape(var_dims, index_dims):
    var_dims_size = len(var_dims)
    index_dims_size = len(index_dims)

    out_dims_size = var_dims_size - index_dims[0] + index_dims_size - 1

    out_dims_shape = [1] * out_dims_size

    out_dims_shape[: index_dims_size - 1] = index_dims[1:]

    out_dims_shape[index_dims_size - 1 :] = var_dims[index_dims[0] :]
    return out_dims_shape
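
# Worked example (hypothetical shapes): for var_dims = [3, 4, 5, 6] indexed by
# 2 index tensors of shape [7] each, index_dims is [2, 7], so the result is
# list(index_dims[1:]) + var_dims[2:] = [7, 5, 6] -- the shape gather_nd would
# produce for an index of shape [7, 2].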


class SliceInfo:
    def __init__(self):
        self.pre_shape = None
        self.indexes = []
        self.dtype = None

    def update(self, index):
        if is_list_tuple(index, int) or isinstance(
            index, (paddle.base.Variable, np.ndarray)
        ):
            # convert index to Tensor
            if not isinstance(index, paddle.base.Variable):
                index = paddle.assign(index)

            if self.dtype is None:
                self.dtype = index.dtype
            else:
                if index.dtype != self.dtype:
                    raise IndexError(
                        "Data type of Tensor/List index should be the same. The current data type is {}, but the previous data type is {}.".format(
                            index.dtype, self.dtype
                        )
                    )

            self.indexes.append(index)

            if self.pre_shape is None:
                self.pre_shape = index.shape
            else:
                if self.pre_shape != index.shape:
                    # broadcast
                    cur_shape = paddle.broadcast_shape(
                        self.pre_shape, index.shape
                    )
                    for i in range(len(self.indexes)):
                        self.indexes[i] = paddle.broadcast_to(
                            self.indexes[i], cur_shape
                        )
                self.pre_shape = self.indexes[-1].shape
        else:
            raise ValueError(
                "Index should be list/tuple of int or Tensor, but received {}.".format(
                    index
                )
            )

    def shape_stride(self, shape):
        s = [1] * len(shape)
        for i in range(len(shape) - 2, -1, -1):
            s[i] = shape[i + 1] * s[i + 1]

        return s
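
    # For intuition: shape_stride([3, 4, 5]) returns [20, 5, 1], the row-major
    # strides (in elements) that set_item below uses to flatten N-D indices
    # into 1-D offsets.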

    def numel(self, shape):
        return reduce(lambda x, y: x * y, shape, 1)

    def get_offset_stride(self, tensor_shape):
        for index in self.indexes:
            if not isinstance(index, paddle.base.Variable):
                raise ValueError(
                    "only support list/tensor index, but received {}.".format(
                        type(index)
                    )
                )

        if len(self.indexes) <= len(tensor_shape) or len(self.indexes) == 1:
            shape = paddle.stack(self.indexes)
            axes = list(range(1, len(self.pre_shape) + 1)) + [
                0,
            ]

        else:
            raise ValueError(
                "too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format(
                    len(tensor_shape), self.pre_shape[0]
                )
            )

        shape_transpose = paddle.transpose(shape, axes)
        return shape_transpose

    def get_item(self, tensor):
        shape_transpose = self.get_offset_stride(tensor.shape)
        index = paddle.assign(shape_transpose)
        return paddle.gather_nd(tensor, index)

    def set_item(self, tensor_origin, value):
        if not isinstance(value, paddle.base.Variable):
            value = paddle.assign(value)
        tensor_type = None

        if tensor_origin.dtype in [
            core.VarDesc.VarType.FP32,
            core.VarDesc.VarType.FP64,
        ]:
            tensor = tensor_origin
        else:
            tensor_type = tensor_origin.dtype
            tensor = tensor_origin.astype(core.VarDesc.VarType.FP32)

        if value.dtype != tensor.dtype:
            value = value.astype(tensor.dtype)

        shape_transpose = self.get_offset_stride(tensor_origin.shape)
        index = paddle.assign(shape_transpose)

        gather_tensor_shape = get_list_index_shape(
            tensor.shape,
            [
                len(self.indexes),
            ]
            + list(self.indexes[-1].shape),
        )

        value_dims_bd = [
            1,
        ] * len(gather_tensor_shape)
        value_dims_bd[-len(value.shape) :] = list(value.shape)

        for i in range(len(gather_tensor_shape)):
            if not (
                len(value_dims_bd) == 0
                or value_dims_bd[i] == gather_tensor_shape[i]
                or value_dims_bd[i] == 1
            ):
                raise ValueError(
                    "{} can not be broadcast to {}".format(
                        value.shape, gather_tensor_shape
                    )
                )

        value_broadcast = paddle.broadcast_to(value, gather_tensor_shape)

        value_1d = value_broadcast.reshape(
            [-1] + gather_tensor_shape[len(index.shape) - 1 :]
        )

        index_1d = index.reshape([-1, index.shape[-1]])

        tensor_stride = paddle.assign(
            self.shape_stride(tensor.shape[: index.shape[-1]])
        )
        inds = []
        for i in range(index_1d.shape[0]):
            temp = (index_1d[i] * tensor_stride).sum()
            inds.append(temp)
        index_1d = paddle.stack(inds).reshape([-1])
        t_reshape = tensor.reshape([-1] + list(tensor.shape[index.shape[-1] :]))
        out = paddle.scatter(t_reshape, index_1d, value_1d)
        if tensor_type is not None:
            out = out.astype(tensor_type)
        tensor_origin = _setitem_impl_(
            tensor_origin, ..., out.reshape(tensor_origin.shape)
        )

        return tensor_origin


def replace_ellipsis(var, item):
    from .framework import Variable

    # Use slice(None) to replace Ellipsis.
    # For var, var.shape = [3,4,5,6]
    #
    #   var[..., 1:2] -> var[:, :, :, 1:2]
    #   var[0, ...] -> var[0]
    #   var[0, ..., 1:2] -> var[0, :, :, 1:2]

    item = list(item)

    # Remove Variable and np.ndarray to skip a bug when counting Ellipsis
    item_remove_var = [
        ele
        for ele in item
        if not isinstance(ele, (Variable, np.ndarray)) and ele is not None
    ]
    ell_count = item_remove_var.count(Ellipsis)
    if ell_count == 0:
        return item
    elif ell_count > 1:
        raise IndexError("An index can only have a single ellipsis ('...')")

    ell_idx = item.index(Ellipsis)

    if ell_idx == len(item) - 1:
        return item[:-1]
    else:
        item[ell_idx : ell_idx + 1] = [slice(None)] * (
            len(var.shape) - len(item) + item.count(None) + 1
        )

    return item


def replace_ndarray_and_range(item):
    new_item = []
    for slice_item in item:
        if isinstance(slice_item, np.ndarray):
            new_item.append(paddle.assign(slice_item))
        elif isinstance(slice_item, range):
            new_item.append(list(slice_item))
        else:
            new_item.append(slice_item)
    return new_item
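
# For example (hypothetical index): in (range(3), np.array([0, 1])), range(3)
# is expanded to the list [0, 1, 2] and the ndarray is converted to a Tensor
# via paddle.assign, so later parsing only has to handle a few index types.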


def replace_none(item):
    new_item = []
    none_axes = []
    for i, slice_item in enumerate(item):
        if slice_item is None:
            none_axes.append(i)
        else:
            new_item.append(slice_item)
    return new_item, none_axes
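
# For example: replace_none((None, 0, slice(None), None)) returns
# ([0, slice(None)], [0, 3]) -- the Nones are stripped and their positions
# recorded so the corresponding axes can be unsqueezed afterwards.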


def is_integer_or_scalar_tensor(ele):
    from .framework import Variable

    if isinstance(ele, int):
        return True
    elif isinstance(ele, Variable):
        # NOTE(zoooo0820): For compatibility, if FLAGS_set_to_1d is set to True,
        # 1-D tensor is still treated as a scalar, which means basic indexing.
        # This will be removed in future.
        if paddle.get_flags('FLAGS_set_to_1d')['FLAGS_set_to_1d']:
            if len(ele.shape) == 1 and ele.shape[0] == 1:
                warnings.warn(
                    "1-D Tensor will be treat as advanced indexing in future version. Currently, 1-D Tensor means a scalar, not vector, and please modify it to 0-D Tensor. If advanced indexing is needed, please use `export FLAGS_set_to_1d=False` to set the flag."
                )
                return True
        if len(ele.shape) == 0 and ele.dtype != paddle.bool:
            return True
    return False


def is_bool_tensor(ele):
    from .framework import Variable

    if isinstance(ele, Variable) and ele.dtype == paddle.bool:
        return True
    return False


def deal_attrs(attrs, attr, attr_name, tensor_attr_name, inputs, infer_flags):
    from .framework import Variable

    if paddle.utils._contain_var(attr):
        inputs[tensor_attr_name] = paddle.utils._convert_to_tensor_list(
            attr, dtype="int64"
        )
        for i, dim in enumerate(attr):
            if isinstance(dim, Variable):
                attrs[attr_name].append(-1)
                infer_flags[i] = -1
            else:
                attrs[attr_name].append(dim)
    else:
        attrs[attr_name] = attr


# the item is a tensor of bool
def get_value_for_bool_tensor(var, item):
    if len(item.shape) > len(var.shape):
        raise IndexError(
            "The dims of bool index doesn't match indexed array, "
            "the dims of bool index are expected to be equal to or less "
            "than {}, but received {}.".format(len(var.shape), len(item.shape))
        )
    i = 0
    item_shape = item.shape
    while i < len(item.shape):
        dim_len = item_shape[i]
        if dim_len != -1 and var.shape[i] != -1 and dim_len != var.shape[i]:
            raise IndexError(
                "The dimension of bool index doesn't match indexed array along "
                "dimension {}, the target dimension is {}, but received {}.".format(
                    i, var.shape[i], dim_len
                )
            )
        i += 1
    empty_shape = [0] + list(var.shape[i:])

    def idx_not_empty(var, item):
        from ..tensor import gather_nd

        bool_2_idx = paddle.nonzero(item)
        return gather_nd(var, bool_2_idx)

    from paddle.static.nn import cond

    return cond(
        item.any(),
        lambda: idx_not_empty(var, item),
        lambda: paddle.empty(empty_shape, var.dtype),
    )
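
# A minimal sketch of the behavior (hypothetical values): for var of shape
# [3, 4] and item = a bool Tensor [True, False, True], the nonzero positions
# [[0], [2]] are gathered, giving a [2, 4] result; if no entry is True, cond()
# returns an empty Tensor of shape [0, 4] instead.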


def _setitem_for_tensor_array(var, item, value):
    """branches for tensor array setitem operation.
    An item can be:
    (1) int/Variable, which is a simple number/variable such as [1], [-2]
    (2) Slice, which is represented by bounds such as [2:-1]
    (3) Tuple, which includes the above two cases such as [2:-1, 1]
    If item is case (1), we perform paddle.tensor.array_write;
    in other cases, we raise a NotImplementedError.
    """

    from .framework import Variable

    assert (
        not paddle.in_dynamic_mode()
    ), "setitem for tensor_array must be called in static graph mode."
    if isinstance(item, (Variable, int)):
        from paddle.jit.dy2static.variable_trans_func import (
            to_static_variable,
        )
        from paddle import cast
        from paddle.tensor import array_write

        item = paddle.cast(to_static_variable(item), dtype='int64')
        value = to_static_variable(value)
        return array_write(x=value, i=item, array=var)
    else:
        raise NotImplementedError(
            "Only support __setitem__ by Int/Variable in tensor_array, but received {}".format(
                type(item)
            )
        )
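
# For example (hypothetical static-mode usage): `arr[i] = t` with an int or
# 0-D int64 Tensor `i` lowers to array_write(x=t, i=i, array=arr); slices such
# as `arr[1:3] = t` fall into the NotImplementedError branch above.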


def _setitem_impl_(var, item, value):
    from paddle.base import core
    from .framework import default_main_program, Variable

    if var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
        return _setitem_for_tensor_array(var, item, value)

    inputs = {'Input': var}

    # 1. Parse item
    if not isinstance(item, tuple):
        item = (item,)

    decrease_axes = []
    axes = []
    starts = []
    ends = []
    steps = []

    item = replace_ndarray_and_range(item)
    item = replace_ellipsis(var, item)
    item, none_axes = replace_none(item)
    slice_info = SliceInfo()
    dim = 0
    for _, slice_item in enumerate(item):
        if is_integer_or_scalar_tensor(slice_item) and not is_bool_tensor(
            slice_item
        ):
            decrease_axes.append(dim)
            start = slice_item
            end = slice_item + 1 if slice_item != -1 else MAX_INTEGER
            step = 1

        elif isinstance(slice_item, slice):
            start = slice_item.start
            end = slice_item.stop
            step = slice_item.step

            if start is None and end is None and step is None:
                dim += 1
                continue

            step = 1 if step is None else step

            if not isinstance(step, Variable) and step == 0:
                raise ValueError(
                    "When assigning a value to a paddle.Tensor, step can not be 0, "
                    "but received step is {}.".format(step)
                )

            if isinstance(step, Variable) and (start is None or end is None):
                raise ValueError(
                    "When assigning a value to a paddle.Tensor, the start or end "
                    "can not be None when the type of step is paddle.Tensor."
                )

            if start is None:
                start = 0 if step > 0 else MAX_INTEGER

            if end is None:
                end = MAX_INTEGER if step > 0 else (0 - MAX_INTEGER)
        elif isinstance(slice_item, list):
            if is_list_tuple(slice_item, int):
                slice_info.update(slice_item)
                continue

            for i in slice_item:
                if not isinstance(i, bool):
                    raise TypeError(
                        "Doesn't support {} in index list.".format(type(i))
                    )

            if len(item) != 1:
                raise IndexError(
                    "When index contains a bool list, its length must be 1, but received {}.".format(
                        len(item)
                    )
                )

            idx_tensor = paddle.assign(slice_item)
            return set_value_for_bool_tensor(var, idx_tensor, value)

        elif isinstance(slice_item, Variable):
            if slice_item.dtype == core.VarDesc.VarType.BOOL:
                if len(item) != 1:
                    raise IndexError(
                        "When index contains a bool tensor, its length must be 1, but received {}.".format(
                            len(item)
                        )
                    )
                return set_value_for_bool_tensor(var, slice_item, value)
            else:
                slice_info.update(slice_item)
                continue
        else:
            raise IndexError(
                "Valid index accepts int, slice, ellipsis, None, list of bool, Variable, "
                "but received {}.".format(slice_item)
            )

        axes.append(dim)
        starts.append(start)
        ends.append(end)
        steps.append(step)

        dim += 1
    if slice_info.indexes:
        if len(slice_info.indexes) != len(item):
            raise IndexError(
                "Valid index accepts int, slice, ellipsis or list, but received {}.".format(
                    item
                )
            )
        return slice_info.set_item(var, value)
    attrs = {
        'axes': axes,
        'starts': starts,
        'ends': ends,
        'steps': steps,
        'decrease_axes': decrease_axes,
        'none_axes': none_axes,
    }

    if paddle.utils._contain_var(starts):
        inputs['StartsTensorList'] = paddle.utils._convert_to_tensor_list(
            starts
        )
        del attrs['starts']
    if paddle.utils._contain_var(ends):
        inputs['EndsTensorList'] = paddle.utils._convert_to_tensor_list(ends)
        del attrs['ends']
    if paddle.utils._contain_var(steps):
        inputs['StepsTensorList'] = paddle.utils._convert_to_tensor_list(steps)
        del attrs['steps']

    # 2. Parse value
    dtype = var.dtype
    attrs['dtype'] = dtype

    from .data_feeder import convert_dtype

    #  2.1 value is an integer, float or complex
    if isinstance(value, (bool, int, float, complex)):
        value = np.array([value]).astype(convert_dtype(dtype))

    #  2.2 value is a np.ndarray
    if isinstance(value, np.ndarray):
        shape = list(value.shape)
        values = value.ravel().tolist()
        attrs["values"] = values
        attrs["shape"] = shape

    elif isinstance(value, (Variable, core.eager.Tensor)):
        inputs["ValueTensor"] = value
    else:
        raise TypeError(
            "Only support to assign an integer, float, numpy.ndarray or "
            "paddle.Tensor to a paddle.Tensor, but received {}".format(
                type(value)
            )
        )

    if paddle.in_dynamic_mode():
        var._bump_inplace_version()
        output = var
    else:
        helper = paddle.base.layer_helper.LayerHelper('set_value', **locals())
        if helper.main_program.current_block_idx != 0:
            # not in global block, we should create a global variable.
            output = helper._create_global_variable_for_type_inference(
                dtype=var.dtype
            )
        else:
            output = helper.create_variable_for_type_inference(dtype=var.dtype)

    cur_block = default_main_program().current_block()
    cur_block.append_op(
        type="set_value",
        inputs=inputs,
        outputs={'Out': output},
        attrs=attrs,
        inplace_map={"Input": "Out"},
    )

    if not paddle.in_dynamic_mode():
        # map var to the new output
        from paddle.jit.dy2static.program_translator import (
            ProgramTranslator,
        )

        ProgramTranslator.get_instance()._inplace_map.add(
            cur_block.program, var.desc.id(), output
        )

    return output
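
# A small sketch of the lowering (hypothetical shapes): in static mode,
# `x[0:2, 1] = 3.0` parses to axes=[0, 1], starts=[0, 1], ends=[2, 2],
# steps=[1, 1], decrease_axes=[1], and appends one set_value op whose attrs
# carry shape=[1] and values=[3.0]; in dynamic mode the same op runs in place
# on x after _bump_inplace_version().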


# the item is a tensor of bool
def set_value_for_bool_tensor(var, item, value):
    if len(item.shape) > len(var.shape):
        raise IndexError(
            "The dims of bool index doesn't match indexed array, "
            "the dims of bool index are expected to be equal to or less "
            "than {}, but received {}.".format(len(var.shape), len(item.shape))
        )
    for i, dim_len in enumerate(item.shape):
        if dim_len != -1 and var.shape[i] != -1 and dim_len != var.shape[i]:
            raise IndexError(
                "The dimension of bool index doesn't match indexed array along "
599 600 601 602
                "dimension {}, the target dimension is {}, but received {}.".format(
                    i, var.shape[i], dim_len
                )
            )

    def idx_not_empty(var, item, value):
        from .framework import Variable
        from ..tensor import gather_nd, scatter_nd_add

        if not isinstance(value, Variable):
            value = paddle.assign(value).cast(var.dtype)

        idx = paddle.nonzero(item)
        gather_val = gather_nd(var, idx)
        gather_val_new = value - gather_val
        out = scatter_nd_add(var, idx, gather_val_new)
        var = _setitem_impl_(var, ..., out)
        return var

    def idx_is_empty(var):
        return var

    from paddle.static.nn import cond

    # If all entries of the bool index are False, just do nothing
    var = cond(
        item.any(),
        lambda: idx_not_empty(var, item, value),
        lambda: idx_is_empty(var),
    )

    return var


def deal_advanced_index(ori_tensor, indices, is_for_setitem):
    """
    Transpose origin Tensor and advanced indices to the front.

    Returns:
        transed_tensor (Tensor): transposed tensor, corresponding to advanced indices
        transed_index (List): advanced indices transed to the front
        trans_back_dim (List): order of axes to transpose back to original order. Only used in __setitem__.
        pos_of_new_dim (int):  axis of new dim in the result. Only used in __getitem__.
        rank_of_new_dim (int): rank of new dim in the result. Only used in __getitem__.
    """
    transed_dim = []
    transed_index = []

    # These flags indicate whether the result obtained by gather_nd requires a second transpose.
    # Only used in __getitem__.
    pos_of_new_dim = MAX_INTEGER
    rank_of_new_dim = 1

    for i, indice in enumerate(indices):
        if indice is not None:
            if not is_for_setitem:
                if i == 0:
                    # case 1: advanced indices at axis 0, the new dim will be at first.
                    pos_of_new_dim = 0
                if i > 0 and len(transed_dim) > 0 and transed_dim[-1] != i - 1:
                    # case 2: there are not adjacent advanced indices, the new dim will be at first.
                    pos_of_new_dim = 0
                else:
                    pos_of_new_dim = min(pos_of_new_dim, i)
                rank_of_new_dim = max(rank_of_new_dim, indice[1].ndim)
            transed_dim.append(i)
            transed_index.append(indice[1])
    for i in range(ori_tensor.ndim):
        if indices[i] is None:
            transed_dim.append(i)
    transed_tensor = ori_tensor.transpose(transed_dim)

    trans_back_dim = np.argsort(transed_dim).tolist() if is_for_setitem else []

    return (
        transed_tensor,
        transed_index,
        trans_back_dim,
        pos_of_new_dim,
        rank_of_new_dim,
    )
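
# For intuition (hypothetical shapes): for an ori_tensor of ndim 3 with one
# advanced index at axis 1 (e.g. x[:, idx, :]), transed_dim is [1, 0, 2], so
# the indexed axis is transposed to the front; trans_back_dim ([1, 0, 2] here)
# restores the original order in __setitem__.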


def parse_index(x, indices):
    advanced_index = [None] * 2 * len(x.shape)  # content is (dim, index)
    # for set_value / slice / strided_slice OP
    decrease_axes = []
    axes = []
    starts = []
    ends = []
    steps = []
    use_strided_slice = False
    has_advanced_index = False

    if not isinstance(indices, tuple):
        indices = (indices,)

    indices = replace_ndarray_and_range(indices)
    indices = replace_ellipsis(x, indices)
    indices, none_axes = replace_none(indices)

    is_tensor_array = (
        hasattr(x, "desc")
        and x.desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY
    )

    estimated_dim = 0
    for dim, slice_item in enumerate(indices):
        start, end, step = None, None, None
        if is_integer_or_scalar_tensor(slice_item):
            if (
                not is_tensor_array
                and isinstance(slice_item, int)
                and x.shape[dim] is not None
                and x.shape[dim] >= 0
                and slice_item >= x.shape[dim]
            ):
                # For python, if users write a, b = var, the __getitem__
                # method will iterate through 0, 1, 2 ... until __getitem__
                # throws an IndexError, then stop. The var[0], var[1] will
                # be given to a, b respectively. If more values are given,
                # the unpack size would cause error.
                # We raise an IndexError here to support grammar like `a, b = var`
                raise IndexError(
                    "slice_item %d at dim %d should be >= 0 and < x.shape[%d]: %d"
                    % (slice_item, dim, dim, x.shape[dim])
                )
            # not calculate result to reduce call times for slice OP.
            decrease_axes.append(dim)
            start = slice_item
            step = 1
            end = slice_item + 1 if slice_item != -1 else MAX_INTEGER
        elif isinstance(slice_item, bool):
            # single bool is advanced-indexing
            none_axes.append(dim)
            estimated_dim += 1
            advanced_index[estimated_dim] = (
                estimated_dim,
                paddle.to_tensor(slice_item),
            )
            has_advanced_index = True
        elif isinstance(slice_item, slice):
            start = slice_item.start
            end = slice_item.stop
            step = slice_item.step
            estimated_dim += 1

            if start is None and end is None and step is None:
                continue

            step = 1 if step is None else step
            if start is None:
                start = 0 if step > 0 else MAX_INTEGER
            if end is None:
                end = MAX_INTEGER if step > 0 else -1

        elif isinstance(slice_item, (list, tuple)):
            advanced_index[estimated_dim] = (
                estimated_dim,
                paddle.to_tensor(slice_item),
            )

            if (
                advanced_index[estimated_dim][1].dtype == paddle.bool
                and len(slice_item) != x.shape[dim]
            ):
                raise IndexError(
                    "The shape of boolean index {} did not match indexed tensor {} along axis {}".format(
                        len(slice_item), x.shape[dim], dim
                    )
                )

            has_advanced_index = True
            estimated_dim += 1

        elif isinstance(slice_item, paddle.base.Variable):
            # In this case, the Variable is not a 0-D Tensor and will be treated as advanced-indexing.
            if slice_item.dtype == paddle.bool:
                if slice_item.ndim == 0:
                    # 0-D bool Tensor, same as single PY-bool.
                    none_axes.append(dim)

                elif slice_item.shape[0] != x.shape[dim]:
                    raise IndexError(
                        "The shape of boolean index {} did not match indexed tensor {} along axis {}".format(
                            slice_item.shape[0], x.shape[dim], dim
                        )
                    )
            advanced_index[estimated_dim] = (estimated_dim, slice_item)
            has_advanced_index = True
            estimated_dim += 1

        else:
            raise IndexError(
                "Valid index accept int / bool / slice / ellipsis / list / Tuple / Ndarray / Tensor, but received {}.".format(
                    slice_item
                )
            )
        if not (start is None or end is None or step is None):
            starts.append(start)
            ends.append(end)
            steps.append(step)
            axes.append(dim)
            use_strided_slice = (
                True
                if (isinstance(step, paddle.base.Variable) or step != 1)
                else use_strided_slice
            )
    return (
        starts,
        ends,
        steps,
        axes,
        none_axes,
        decrease_axes,
        advanced_index,
        has_advanced_index,
        use_strided_slice,
    )
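
# A worked example (hypothetical input): for a suitable x,
# parse_index(x, (0, slice(1, 5, 2), None)) gives starts=[0, 1], ends=[1, 5],
# steps=[1, 2], axes=[0, 1], none_axes=[2], decrease_axes=[0],
# has_advanced_index=False, and use_strided_slice=True (since step 2 != 1).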


def _setitem_static(x, indices, values):
    """
    In dynamic mode, this function will modify the value of the input tensor in place, returning the same Tensor as the input.
    But it will return a new Tensor with assigned value in static mode.

    Args:
        x(Tensor): Tensor to be set value.
        indices(int|slice|None|Tensor|List|Tuple...): Indices, used to indicate the position of the element to be fetched.
        values(Tensor|Number|Ndarray): values to be assigned to x.
    """
    from .framework import default_main_program, Variable

    if x.type == paddle.base.core.VarDesc.VarType.LOD_TENSOR_ARRAY:
        return _setitem_for_tensor_array(x, indices, values)

    # step1: parsing the index and recording them
    (
        starts,
        ends,
        steps,
        axes,
        none_axes,
        decrease_axes,
        advanced_index,
        has_advanced_index,
        use_strided_slice,
    ) = parse_index(x, indices)

    inputs = {'Input': x}
    attrs = {
        'axes': axes,
        'starts': starts,
        'ends': ends,
        'steps': steps,
        'decrease_axes': decrease_axes,
        'none_axes': none_axes,
    }
    if paddle.utils._contain_var(starts):
        inputs['StartsTensorList'] = paddle.utils._convert_to_tensor_list(
            starts
        )
        del attrs['starts']
    if paddle.utils._contain_var(ends):
        inputs['EndsTensorList'] = paddle.utils._convert_to_tensor_list(ends)
        del attrs['ends']
    if paddle.utils._contain_var(steps):
        inputs['StepsTensorList'] = paddle.utils._convert_to_tensor_list(steps)
        del attrs['steps']

    if not has_advanced_index:
        # step2. Parse values
        dtype = x.dtype
        attrs['dtype'] = dtype

        from .data_feeder import convert_dtype

        if isinstance(values, (bool, int, float, complex)):
            values = np.array([values]).astype(convert_dtype(dtype))

        if isinstance(values, np.ndarray):
            shape = list(values.shape)
            values = values.ravel().tolist()
            attrs["values"] = values
            attrs["shape"] = shape

        elif isinstance(values, Variable):
            inputs["ValueTensor"] = values
        else:
            raise TypeError(
                "Only support to assign an integer, float, numpy.ndarray or "
                "paddle.Tensor to a paddle.Tensor, but received {}".format(
                    type(values)
                )
            )

        # step3.1: Only basic indexing, use OP set_value to set value.
        if paddle.in_dynamic_mode():
            x._bump_inplace_version()
            output = x
        else:
            helper = paddle.base.layer_helper.LayerHelper(
                'set_value', **locals()
            )
            if helper.main_program.current_block_idx != 0:
                # not in global block, we should create a global variable.
                output = helper._create_global_variable_for_type_inference(
                    dtype=x.dtype
                )
            else:
                output = helper.create_variable_for_type_inference(
                    dtype=x.dtype
                )
        cur_block = default_main_program().current_block()
        cur_block.append_op(
            type="set_value",
            inputs=inputs,
            outputs={'Out': output},
            attrs=attrs,
            inplace_map={"Input": "Out"},
        )

        if not paddle.in_dynamic_mode():
            # map var to the new output
            paddle.jit.api.ProgramTranslator.get_instance()._inplace_map.add(
                cur_block.program, x.desc.id(), output
            )
        return output
    else:
        # step3.2: Case for there are advanced indexing.
        #   1. get __getitem__ result of basic indexing;
        #   2. transpose original tensor so that the axis with advanced indexing will come to the first;
        #   3. assign values to the sliced result by index_put OP;
        #   4. transpose back and assign the result to original tensor by set_value OP.

        sub_tensor = get_tensor_with_basic_indexing(
            x,
            axes,
            starts,
            ends,
            steps,
            decrease_axes,
            none_axes,
            use_strided_slice,
        )
        (
            transed_sub_tensor,
            adjusted_advanced_index,
            transback_dim,
            _,
            _,
        ) = deal_advanced_index(sub_tensor, advanced_index, True)
        if not isinstance(values, Variable):
            values = paddle.assign(values).astype(transed_sub_tensor.dtype)

        # NOTE(zoooo0820): now basic indexing of __getitem__ will return a new Tensor both in dynamic and static mode
        # After strided is ready and basic indexing returns a view of the Tensor in dynamic mode, the code should be
        # changed for dynamic mode.
        if paddle.in_dynamic_mode():
            transed_sub_tensor.index_put_(adjusted_advanced_index, values)
        else:
            transed_sub_tensor = transed_sub_tensor.index_put(
                adjusted_advanced_index, values
            )

        transback_sub_tensor = transed_sub_tensor.transpose(transback_dim)

        inputs["ValueTensor"] = transback_sub_tensor
        if paddle.in_dynamic_mode():
            x._bump_inplace_version()
            output = x
        else:
            helper = paddle.base.layer_helper.LayerHelper(
                'set_value', **locals()
            )
            if helper.main_program.current_block_idx != 0:
                # not in global block, we should create a global variable.
                output = helper._create_global_variable_for_type_inference(
                    dtype=x.dtype
                )
            else:
                output = helper.create_variable_for_type_inference(
                    dtype=x.dtype
                )
        cur_block = default_main_program().current_block()
        cur_block.append_op(
            type="set_value",
            inputs=inputs,
            outputs={'Out': output},
            attrs=attrs,
            inplace_map={"Input": "Out"},
        )
        if not paddle.in_dynamic_mode():
            # map var to the new output
            paddle.jit.api.ProgramTranslator.get_instance()._inplace_map.add(
                cur_block.program, x.desc.id(), output
            )
        return output
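
# For intuition (hypothetical shapes): `x[:, [0, 2]] = v` takes the advanced
# branch above: the basic-indexing sub-tensor (x itself here) is transposed so
# the indexed axis comes first, index_put writes v at rows 0 and 2, the result
# is transposed back, and a set_value op assigns it to the original region.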


def get_tensor_with_basic_indexing(
    x, axes, starts, ends, steps, decrease_axes, none_axes, use_strided_slice
):
    from .dygraph.base import in_to_static_mode

    if in_to_static_mode() and hasattr(x, "is_view_var"):
        x.is_view_var = True

    if len(axes) == 0:
        out = x
    else:
        op_type = "strided_slice" if use_strided_slice else "slice"
        inputs = {'Input': [x]}
        attrs = {
            'axes': axes,
            'starts': [],
            'ends': [],
            'decrease_axis': decrease_axes,
        }
        if use_strided_slice:
            attrs['strides'] = []
        infer_flags = [1] * len(axes)
        deal_attrs(
            attrs, starts, "starts", "StartsTensorList", inputs, infer_flags
        )
        deal_attrs(attrs, ends, "ends", "EndsTensorList", inputs, infer_flags)
        deal_attrs(
            attrs, steps, "strides", "StridesTensorList", inputs, infer_flags
        )
        attrs['infer_flags'] = infer_flags

        if paddle.in_dynamic_mode():
            if "StartsTensorList" in inputs.keys():
                st = inputs['StartsTensorList']
            else:
                st = attrs['starts']
            if "EndsTensorList" in inputs.keys():
                end = inputs['EndsTensorList']
            else:
                end = attrs['ends']
            if "StridesTensorList" in inputs.keys():
                stride = inputs['StridesTensorList']
            else:
                stride = attrs['strides']
            if use_strided_slice:
                out = paddle._C_ops.strided_slice(x, axes, st, end, stride)
                if len(decrease_axes) > 0:
                    out = paddle._C_ops.squeeze(out, decrease_axes)
            else:
                out = paddle._C_ops.slice(
                    x,
                    axes,
                    st,
                    end,
                    attrs['infer_flags'],
                    attrs['decrease_axis'],
                )
        else:
            from .framework import default_main_program

            target_block = default_main_program().current_block()

            slice_out_var = target_block.create_var(
                name=unique_name.generate_with_ignorable_key(
                    x.name + "_" + op_type
                ),
                dtype=x.dtype,
            )
            target_block.append_op(
                type=op_type,
                inputs=inputs,
                outputs={'Out': [slice_out_var]},
                attrs=attrs,
            )
            out = slice_out_var
    # NOTE(zoooo0820): When all axes are decreased, the output will be 1-D
    # with FLAGS_set_to_1d=True. In this case, one `None` should be popped out,
    # otherwise the output shape will not be correct.
    set_to_1d = paddle.get_flags('FLAGS_set_to_1d')['FLAGS_set_to_1d']
    if set_to_1d and len(decrease_axes) == len(x.shape):
        warnings.warn(
            "Warning: In Tensor '__getitem__', if the number of scalar elements in the index is equal to the rank of the Tensor, the output should be 0-D. In order to be consistent with the behavior of previous versions, it will be processed to 1-D. But it is not correct and will be removed in release 2.6. If 1-D is still wanted, please modify the index element from scalar to slice (e.g. 'x[i]' => 'x[i:i+1]')."
        )
        none_axes = none_axes[1:]

    if len(none_axes) > 0:
        # Deal with cases that decrease_axes is not empty
        # For example:
        # # x.shape: (2,3,4)
        # out = x[0, 0:2, None] # out.shape : (2, 1, 4)
        for idx, axis in enumerate(none_axes):
            l = len([i for i in decrease_axes if i < axis])
            new_axis = axis - l
            none_axes[idx] = new_axis

        out = paddle.unsqueeze(out, axis=none_axes)

    if in_to_static_mode() and hasattr(out, "is_view_var"):
        out.is_view_var = True
    return out


def _getitem_static(x, indices):
    """
    Args:
        x(Tensor): Tensor to be indexing.
        indices(int|slice|None|Tensor|List|Tuple...): Indices, used to indicate the position of the element to be fetched.
    """
    # step1: parsing the index and recording them
    (
        starts,
        ends,
        steps,
        axes,
        none_axes,
        decrease_axes,
        advanced_index,
        has_advanced_index,
        use_strided_slice,
    ) = parse_index(x, indices)

    # step2: Dealing with basic indexing
    out = get_tensor_with_basic_indexing(
        x,
        axes,
        starts,
        ends,
        steps,
        decrease_axes,
        none_axes,
        use_strided_slice,
    )

    # step3: Dealing with advanced indexing
    if has_advanced_index:
        (
            transed_tensor,
            adjusted_advanced_index,
            _,
            pos_of_new_dim,
            rank_of_new_dim,
        ) = deal_advanced_index(out, advanced_index, False)

        # TODO(zooooo0820): Replace gather_nd with another advanced OP to handle mixed indexes more efficiently
        if (
            len(adjusted_advanced_index) == 1
            and adjusted_advanced_index[0].dtype == paddle.bool
        ):
            # Note: slice does not support 0-size Tensors yet, so only a single bool tensor index can return an empty 0-size result.
            out = get_value_for_bool_tensor(
                transed_tensor, adjusted_advanced_index[0]
            )
        else:
            adjusted_advanced_index = parse_bool_and_broadcast_indices(
                adjusted_advanced_index
            )
            advanced_index_tensor = paddle.stack(
                adjusted_advanced_index, axis=-1
            )
            out = paddle.gather_nd(transed_tensor, advanced_index_tensor)

        if pos_of_new_dim != 0:
            perm = (
                list(range(pos_of_new_dim, pos_of_new_dim + rank_of_new_dim))
                + list(range(0, pos_of_new_dim))
                + list(range(pos_of_new_dim + rank_of_new_dim, out.ndim))
            )
            out = out.transpose(perm)

    return out
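
# For intuition (hypothetical shapes): with x of shape [3, 4, 5], x[0, [1, 3]]
# first slices axis 0 by basic indexing (shape [4, 5]), then gathers rows 1
# and 3 via gather_nd, returning shape [2, 5]; a single bool tensor index goes
# through get_value_for_bool_tensor instead, so an all-False mask can yield a
# 0-size result.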


def parse_bool_and_broadcast_indices(indices):
    # deal with multiple Tensors and translate bool tensors to int tensors.
    # In static mode, a bool tensor cannot be broadcasted since its corresponding int tensor's shape cannot be inferred.
    for i, indice in enumerate(indices):
        if indice.dtype == paddle.bool:
            indices[i] = paddle.nonzero(indice)[:, 0]
    if len(indices) > 1:
        indices = paddle.broadcast_tensors(indices)
    return indices
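
# For example (hypothetical inputs): a bool index Tensor([True, False, True])
# becomes its integer positions Tensor([0, 2]) via nonzero()[:, 0]; if several
# integer index tensors remain, broadcast_tensors aligns them to one common
# shape so they can be stacked along the last axis for gather_nd.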