From 2b44ae5de9b4e53b2fd7ba992d353025dfc40b5c Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Tue, 15 Jun 2021 15:39:03 +0800 Subject: [PATCH] [cherry-pick] Polish code for setitem/getitem and support index for list/Tensor/None/Ellipsis/bool (#33528) * [cherry-pick 2.1] Polish code for _getitem_impl_ (#32868) * [cherry-pick] Polish code for setitem and getitem (#32911) * [slice getitem] Support getitem idx is Tensor or List (#33000) * [getitem] Support index is None for getitem in static mode (#33001) * [Static getitem] Support static Variable getitem for Ellipsis index (#32876) * [static getitem]Support index is list bool for getitem in static mode (#33298) --- python/paddle/fluid/framework.py | 355 +--------------- .../fluid/tests/unittests/test_variable.py | 181 +++++++- python/paddle/fluid/variable_index.py | 390 ++++++++++++++++++ 3 files changed, 569 insertions(+), 357 deletions(-) create mode 100644 python/paddle/fluid/variable_index.py diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index bc8a06cb1ed..03e7833aca1 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -39,6 +39,7 @@ from . import unique_name import paddle.version as fluid_version import warnings import functools +from .variable_index import _getitem_impl_, _setitem_impl_ __all__ = [ 'Program', @@ -794,205 +795,6 @@ class ParameterMetaClass(VariableMetaClass): return issubclass(t, Parameter) -def _getitem_impl_(var, item): - """ - Slice the variable. - - Args: - item(int/slice/tuple) : the index. - - Returns: - Sliced variable - """ - - if not isinstance(item, tuple): - item = [item] - - decrease_axis = [] - slice_axis = [] - slice_start = [] - slice_end = [] - slice_step = [] - use_strided_slice = False - reverse_axis = [] - target_block = default_main_program().current_block() - - def fill_constant(shape, value, force_cpu=False, out=None): - var.block.append_op( - type='fill_constant', - inputs={}, - outputs={'Out': [out]}, - attrs={ - 'shape': shape, - 'dtype': out.dtype, - 'value': float(value), - 'force_cpu': force_cpu - }) - out.stop_gradient = True - return out - - for dim, slice_item in enumerate(item): - if isinstance(slice_item, slice): - start = slice_item.start - end = slice_item.stop - step = slice_item.step - - if start is None and end is None and step is None: - continue - - if step is None: - step = 1 - - if start is None and end is None: - assert (step == -1) - reverse_axis.append(dim) - continue - - if start is None: - start = 0 - - if end is None: - end = 10000000 - - if step != 1: - use_strided_slice = True - - slice_axis.append(dim) - slice_start.append(start) - slice_end.append(end) - slice_step.append(step) - else: - decrease_axis.append(dim) - slice_axis.append(dim) - slice_start.append(slice_item) - slice_step.append(1) - if isinstance(slice_item, Variable): - temp_1 = var.block.create_var(dtype=slice_item.dtype) - fill_constant([1], 1, force_cpu=True, out=temp_1) - temp_end = target_block.create_var(dtype=slice_item.dtype) - target_block.append_op( - type='elementwise_add', - inputs={'X': slice_item, - 'Y': temp_1}, - outputs={'Out': temp_end}, - attrs={'axis': -1}) - slice_end.append(temp_end) - else: - slice_end.append(slice_item + 1 - if slice_item != -1 else 10000000) - - def contain_var(one_list): - for ele in one_list: - if isinstance(ele, Variable): - return True - return False - - def get_new_list_tensor(old_list): - new_list_tensor = [] - for dim in old_list: - if isinstance(dim, Variable): - dim.stop_gradient = True - new_list_tensor.append(dim) - else: - assert (isinstance(dim, int)) - temp_out = var.block.create_var(dtype='int64') - fill_constant([1], dim, force_cpu=True, out=temp_out) - new_list_tensor.append(temp_out) - return new_list_tensor - - inputs = {'Input': [var]} - attrs = { - 'axes': slice_axis, - 'starts': [], - 'ends': [], - 'decrease_axis': decrease_axis - } - if (use_strided_slice == True): - attrs['strides'] = [] - infer_flags = list(1 for i in range(len(slice_axis))) - - # starts - if contain_var(slice_start): - inputs['StartsTensorList'] = get_new_list_tensor(slice_start) - for i, dim in enumerate(slice_start): - if isinstance(dim, Variable): - attrs['starts'].append(-1) - infer_flags[i] = -1 - else: - attrs['starts'].append(dim) - else: - attrs['starts'] = slice_start - - # ends - if contain_var(slice_end): - inputs['EndsTensorList'] = get_new_list_tensor(slice_end) - for i, dim in enumerate(slice_end): - if isinstance(dim, Variable): - attrs['ends'].append(-1) - infer_flags[i] = -1 - else: - attrs['ends'].append(dim) - else: - attrs['ends'] = slice_end - - # strides - if use_strided_slice == True: - if contain_var(slice_step): - inputs['StridesTensorList'] = get_new_list_tensor(slice_step) - for i, dim in enumerate(slice_step): - if isinstance(dim, Variable): - attrs['strides'].append(-1) - infer_flags[i] = -1 - else: - attrs['strides'].append(dim) - else: - attrs['strides'] = slice_step - # infer_flags - attrs['infer_flags'] = infer_flags - - out = var - if use_strided_slice == False and len(slice_axis) > 0: - # append slice_op here - slice_out_var = target_block.create_var( - name=unique_name.generate_with_ignorable_key(var.name + "_slice"), - dtype=var.dtype) - - target_block.append_op( - type="slice", - inputs=inputs, - outputs={'Out': [slice_out_var]}, - attrs=attrs) - - out = slice_out_var - elif use_strided_slice == True and len(slice_axis) > 0: - strided_slice_out_var = target_block.create_var( - name=unique_name.generate_with_ignorable_key(var.name + - "_strided_slice"), - dtype=var.dtype) - target_block.append_op( - type="strided_slice", - inputs=inputs, - outputs={'Out': [strided_slice_out_var]}, - attrs=attrs) - - out = strided_slice_out_var - - if len(reverse_axis) > 0: - reverse_out_var = target_block.create_var( - name=unique_name.generate_with_ignorable_key(var.name + - "_slice_reverse"), - dtype=var.dtype) - target_block.append_op( - type="reverse", - inputs={'X': out}, - outputs={'Out': [reverse_out_var]}, - attrs={'axis': reverse_axis}) - - out = reverse_out_var - - return out - - @six.add_metaclass(VariableMetaClass) class Variable(object): """ @@ -1848,160 +1650,7 @@ class Variable(object): return _getitem_impl_(self, item) def __setitem__(self, item, value): - inputs = {'Input': self} - - # 1. Parse item - if not isinstance(item, tuple): - item = [item] - - decrease_axes = [] - axes = [] - starts = [] - ends = [] - steps = [] - - max_integer = sys.maxsize - - def replace_ellipsis(item): - # Use slice(None) to replace Ellipsis. - # For var, var.shape = [3,4,5,6] - # - # var[..., 1:2] -> var[:, :, :, 1:2] - # var[0, ...] -> var[0] - # var[0, ..., 1:2] -> var[0, :, :, 1:2] - - item = list(item) - - # Remove Variable to skip bug when counting Ellipsis - item_remove_var = [ - ele for ele in item if not isinstance(ele, Variable) - ] - ell_count = item_remove_var.count(Ellipsis) - if ell_count == 0: - return item - elif ell_count > 1: - raise IndexError( - "An index can only have a single ellipsis ('...')") - - ell_idx = item.index(Ellipsis) - - if ell_idx == len(item) - 1: - return item[:-1] - else: - item[ell_idx:ell_idx + 1] = [slice(None)] * ( - len(self.shape) - len(item) + 1) - - return item - - item = replace_ellipsis(item) - - for dim, slice_item in enumerate(item): - if isinstance(slice_item, slice): - start = slice_item.start - end = slice_item.stop - step = slice_item.step - - if start is None and end is None and step is None: - continue - - step = 1 if step is None else step - - # TODO: support cases when step < 1 - if not isinstance(step, Variable) and step == 0: - raise ValueError( - "When assign a value to a paddle.Tensor, step can not be 0, " - "but received step is {}.".format(step)) - - if isinstance(step, Variable) and (start is None or - end is None): - raise ValueError( - "When assign a value to a paddle.Tensor, it's not supported that " - "the start or end is None when the type of step is paddle.Tensor." - ) - - if start is None: - start = 0 if step > 0 else max_integer - - if end is None: - end = max_integer if step > 0 else (0 - max_integer) - else: - decrease_axes.append(dim) - start = slice_item - end = slice_item + 1 if slice_item != -1 else max_integer - step = 1 - - axes.append(dim) - starts.append(start) - ends.append(end) - steps.append(step) - - attrs = { - 'axes': axes, - 'starts': starts, - 'ends': ends, - 'steps': steps, - 'decrease_axes': decrease_axes - } - - from .layers import utils - if utils._contain_var(starts): - inputs['StartsTensorList'] = utils._convert_to_tensor_list(starts) - del attrs['starts'] - if utils._contain_var(ends): - inputs['EndsTensorList'] = utils._convert_to_tensor_list(ends) - del attrs['ends'] - if utils._contain_var(steps): - inputs['StepsTensorList'] = utils._convert_to_tensor_list(steps) - del attrs['steps'] - - # 2. Parse value - dtype = self.dtype - attrs['dtype'] = dtype - - from .data_feeder import convert_dtype - # 2.1 value is an integer of float - if isinstance(value, (int, float)): - value = np.array([value]).astype(convert_dtype(dtype)) - - # 2.2 value is a np.ndarray - if isinstance(value, np.ndarray): - shape = list(value.shape) - if dtype == core.VarDesc.VarType.BOOL: - value_name = "bool_values" - values = [bool(v) for v in value.flat] - elif dtype == core.VarDesc.VarType.FP32: - value_name = "fp32_values" - values = [float(v) for v in value.flat] - elif dtype == core.VarDesc.VarType.FP64: - value_name = "fp64_values" - values = [float(v) for v in value.flat] - elif dtype == core.VarDesc.VarType.INT32: - value_name = "int32_values" - values = [int(v) for v in value.flat] - elif dtype == core.VarDesc.VarType.INT64: - value_name = "int64_values" - values = [int(v) for v in value.flat] - else: - raise TypeError( - "When assign a numpy.ndarray, integer or float to a paddle.Tensor, " - "the data type of the paddle.Tensor must be bool, float32, int32 or int64, but " - "received %s." % convert_dtype(dtype)) - attrs[value_name] = values - attrs["shape"] = shape - - elif isinstance(value, Variable): - inputs["ValueTensor"] = value - else: - raise TypeError( - "Only support to assign an integer, float, numpy.ndarray or " - "paddle.Tensor to a paddle.Tensor, but received {}".format( - type(value))) - - cur_block = default_main_program().current_block() - cur_block.append_op( - type="set_value", inputs=inputs, outputs={'Out': self}, attrs=attrs) - - return self + return _setitem_impl_(self, item, value) def get_value(self, scope=None): """ diff --git a/python/paddle/fluid/tests/unittests/test_variable.py b/python/paddle/fluid/tests/unittests/test_variable.py index 690ac46e563..c1956545f55 100644 --- a/python/paddle/fluid/tests/unittests/test_variable.py +++ b/python/paddle/fluid/tests/unittests/test_variable.py @@ -15,12 +15,16 @@ from __future__ import print_function import unittest +import paddle from paddle.fluid.framework import default_main_program, Program, convert_np_dtype_to_dtype_, in_dygraph_mode +import paddle import paddle.fluid as fluid import paddle.fluid.layers as layers import paddle.fluid.core as core import numpy as np +paddle.enable_static() + class TestVariable(unittest.TestCase): def test_np_dtype_convert(self): @@ -161,12 +165,125 @@ class TestVariable(unittest.TestCase): self.assertTrue( np.array_equal(local_out[15], tensor_array[::-1, ::-1, ::-1])) - def test_slice(self): - place = fluid.CPUPlace() - self._test_slice(place) + def _test_slice_index_tensor(self, place): + data = np.random.rand(2, 3).astype("float32") + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + x = paddle.assign(data) + idx0 = [1, 0] + idx1 = [0, 1] + idx2 = [0, 0] + idx3 = [1, 1] + + out0 = x[paddle.assign(np.array(idx0))] + out1 = x[paddle.assign(np.array(idx1))] + out2 = x[paddle.assign(np.array(idx2))] + out3 = x[paddle.assign(np.array(idx3))] + + exe = paddle.static.Executor(place) + result = exe.run(prog, fetch_list=[out0, out1, out2, out3]) + + expected = [data[idx0], data[idx1], data[idx2], data[idx3]] + + self.assertTrue((result[0] == expected[0]).all()) + self.assertTrue((result[1] == expected[1]).all()) + self.assertTrue((result[2] == expected[2]).all()) + self.assertTrue((result[3] == expected[3]).all()) + + with self.assertRaises(IndexError): + one = paddle.ones(shape=[1]) + res = x[one, [0, 0]] + + def _test_slice_index_list(self, place): + data = np.random.rand(2, 3).astype("float32") + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + x = paddle.assign(data) + idx0 = [1, 0] + idx1 = [0, 1] + idx2 = [0, 0] + idx3 = [1, 1] + + out0 = x[idx0] + out1 = x[idx1] + out2 = x[idx2] + out3 = x[idx3] + + exe = paddle.static.Executor(place) + result = exe.run(prog, fetch_list=[out0, out1, out2, out3]) + + expected = [data[idx0], data[idx1], data[idx2], data[idx3]] + + self.assertTrue((result[0] == expected[0]).all()) + self.assertTrue((result[1] == expected[1]).all()) + self.assertTrue((result[2] == expected[2]).all()) + self.assertTrue((result[3] == expected[3]).all()) + + def _test_slice_index_ellipsis(self, place): + data = np.random.rand(2, 3, 4).astype("float32") + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + x = paddle.assign(data) + out1 = x[0:, ..., 1:] + out2 = x[0:, ...] + out3 = x[..., 1:] + out4 = x[...] + + exe = paddle.static.Executor(place) + result = exe.run(prog, fetch_list=[out1, out2, out3, out4]) + + expected = [data[0:, ..., 1:], data[0:, ...], data[..., 1:], data[...]] + + self.assertTrue((result[0] == expected[0]).all()) + self.assertTrue((result[1] == expected[1]).all()) + self.assertTrue((result[2] == expected[2]).all()) + self.assertTrue((result[3] == expected[3]).all()) + + with self.assertRaises(IndexError): + res = x[[1, 0], [0, 0]] + + with self.assertRaises(TypeError): + res = x[[1.2, 0]] + + def _test_slice_index_list_bool(self, place): + data = np.random.rand(2, 3).astype("float32") + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + x = paddle.assign(data) + idx0 = [True, False] + idx1 = [False, True] + idx2 = [False, False] + idx3 = [True, True] + + out0 = x[idx0] + out1 = x[idx1] + out2 = x[idx2] + out3 = x[idx3] + + exe = paddle.static.Executor(place) + result = exe.run(prog, fetch_list=[out0, out1, out2, out3]) + + expected = [data[idx0], data[idx1], data[idx2], data[idx3]] + + self.assertTrue((result[0] == expected[0]).all()) + self.assertTrue((result[1] == expected[1]).all()) + self.assertTrue((result[2] == expected[2]).all()) + self.assertTrue((result[3] == expected[3]).all()) + + with self.assertRaises(TypeError): + res = x[[True, 0]] + def test_slice(self): + places = [fluid.CPUPlace()] if core.is_compiled_with_cuda(): - self._test_slice(core.CUDAPlace(0)) + places.append(core.CUDAPlace(0)) + + for place in places: + self._test_slice(place) + self._test_slice_index_tensor(place) + self._test_slice_index_list(place) + self._test_slice_index_ellipsis(place) + self._test_slice_index_list_bool(place) def _tostring(self): b = default_main_program().current_block() @@ -229,5 +346,61 @@ class TestVariable(unittest.TestCase): self.assertRaises(Exception, _test) +class TestVariableSlice(unittest.TestCase): + def _test_item_none(self, place): + data = np.random.rand(2, 3, 4).astype("float32") + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + x = paddle.assign(data) + out0 = x[0:, None, 1:] + out1 = x[0:, None] + out2 = x[None, 1:] + out3 = x[None] + + outs = [out0, out1, out2, out3] + exe = paddle.static.Executor(place) + result = exe.run(prog, fetch_list=outs) + + expected = [ + data[0:, None, 1:], data[0:, None], data[None, 1:], data[None] + ] + for i in range(len(outs)): + self.assertEqual(outs[i].shape, expected[i].shape) + self.assertTrue((result[i] == expected[i]).all()) + + def _test_item_none_and_decrease(self, place): + data = np.random.rand(2, 3, 4).astype("float32") + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + x = paddle.assign(data) + out0 = x[0, 1:, None] + out1 = x[0, None] + out2 = x[None, 1] + out3 = x[None] + out4 = x[0, 0, 0, None] + out5 = x[None, 0, 0, 0, None] + + outs = [out0, out1, out2, out3, out4, out5] + exe = paddle.static.Executor(place) + result = exe.run(prog, fetch_list=outs) + expected = [ + data[0, 1:, None], data[0, None], data[None, 1], data[None], + data[0, 0, 0, None], data[None, 0, 0, 0, None] + ] + + for i in range(len(outs)): + self.assertEqual(outs[i].shape, expected[i].shape) + self.assertTrue((result[i] == expected[i]).all()) + + def test_slice(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(core.CUDAPlace(0)) + + for place in places: + self._test_item_none(place) + self._test_item_none_and_decrease(place) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py new file mode 100644 index 00000000000..c6ddba7fead --- /dev/null +++ b/python/paddle/fluid/variable_index.py @@ -0,0 +1,390 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import numpy as np +from . import unique_name +from . import core + +MAX_INTEGER = 2**31 - 1 + + +def replace_ellipsis(var, item): + from .framework import Variable + # Use slice(None) to replace Ellipsis. + # For var, var.shape = [3,4,5,6] + # + # var[..., 1:2] -> var[:, :, :, 1:2] + # var[0, ...] -> var[0] + # var[0, ..., 1:2] -> var[0, :, :, 1:2] + + item = list(item) + + # Remove Variable to skip bug when counting Ellipsis + item_remove_var = [ele for ele in item if not isinstance(ele, Variable)] + ell_count = item_remove_var.count(Ellipsis) + if ell_count == 0: + return item + elif ell_count > 1: + raise IndexError("An index can only have a single ellipsis ('...')") + + ell_idx = item.index(Ellipsis) + + if ell_idx == len(item) - 1: + return item[:-1] + else: + item[ell_idx:ell_idx + 1] = [slice(None)] * ( + len(var.shape) - len(item) + 1) + + return item + + +def replace_none(item): + new_item = [] + none_axes = [] + for i, slice_item in enumerate(item): + if slice_item is None: + none_axes.append(i) + else: + new_item.append(slice_item) + return new_item, none_axes + + +def is_integer_or_scalar_tensor(ele): + from .framework import Variable + if isinstance(ele, int): + return True + elif isinstance(ele, Variable): + if len(ele.shape) == 1 and ele.shape[0] == 1: + return True + return False + + +def deal_attrs(attrs, attr, attr_name, tensor_attr_name, inputs, infer_flags): + from .framework import Variable + from .layers import utils + + if utils._contain_var(attr): + inputs[tensor_attr_name] = utils._convert_to_tensor_list( + attr, dtype="int64") + for i, dim in enumerate(attr): + if isinstance(dim, Variable): + attrs[attr_name].append(-1) + infer_flags[i] = -1 + else: + attrs[attr_name].append(dim) + else: + attrs[attr_name] = attr + + +def _getitem_impl_(var, item): + """ + Slice the variable. + + Args: + item(int/slice/tuple) : the index. + + Returns: + Sliced variable + """ + from .framework import default_main_program, Variable + + if not isinstance(item, tuple): + item = (item, ) + + decrease_axes = [] + axes = [] + starts = [] + ends = [] + steps = [] + reverse_axes = [] + + use_strided_slice = False + item, none_axes = replace_none(item) + item = replace_ellipsis(var, item) + + for dim, slice_item in enumerate(item): + if is_integer_or_scalar_tensor(slice_item): + decrease_axes.append(dim) + start = slice_item + step = 1 + end = slice_item + 1 if slice_item != -1 else MAX_INTEGER + + elif isinstance(slice_item, slice): + start = slice_item.start + end = slice_item.stop + step = slice_item.step + + if start is None and end is None and step is None: + continue + + step = 1 if step is None else step + + if start is None and end is None: + assert (step == -1) + reverse_axes.append(dim) + continue + + start = 0 if start is None else start + end = MAX_INTEGER if end is None else end + + elif isinstance(slice_item, list): + is_bool_list = False + for i in slice_item: + if not isinstance(i, (int, bool)): + raise TypeError("Only support int or bool in index list.") + + if isinstance(i, bool): + is_bool_list = True + break + + if len(item) != 1: + raise IndexError( + "When index contains a list, its length must be 1, but received {}". + format(len(item))) + + if is_bool_list: + new_slice_item = [] + for idx, ele in enumerate(slice_item): + if not isinstance(ele, bool): + raise TypeError( + "Mixed bool index with other types is not supported." + ) + + if ele is True: + new_slice_item.append(idx) + slice_item = new_slice_item + + from .layers import assign + from ..tensor import index_select + + idx = assign(np.array(slice_item).astype("int32")) + return index_select(var, index=idx, axis=0) + + elif isinstance(slice_item, Variable): + if len(item) != 1: + raise IndexError( + "When index contains a Tensor, its length must be 1, but received {}". + format(len(item))) + + from ..tensor import index_select + return index_select(var, index=slice_item, axis=0) + + else: + raise IndexError( + "Valid index accept int or slice or ellipsis, but received {}.". + format(slice_item)) + + axes.append(dim) + starts.append(start) + ends.append(end) + steps.append(step) + use_strided_slice = True if step != 1 else use_strided_slice + + inputs = {'Input': [var]} + attrs = { + 'axes': axes, + 'starts': [], + 'ends': [], + 'decrease_axis': decrease_axes + } + if use_strided_slice: + attrs['strides'] = [] + + infer_flags = [1] * len(axes) + deal_attrs(attrs, starts, "starts", "StartsTensorList", inputs, infer_flags) + deal_attrs(attrs, ends, "ends", "EndsTensorList", inputs, infer_flags) + deal_attrs(attrs, steps, "strides", "StridesTensorList", inputs, + infer_flags) + attrs['infer_flags'] = infer_flags + + out = var + if len(axes) > 0: + target_block = default_main_program().current_block() + op_type = "strided_slice" if use_strided_slice else "slice" + + slice_out_var = target_block.create_var( + name=unique_name.generate_with_ignorable_key(var.name + "_" + + op_type), + dtype=var.dtype) + target_block.append_op( + type=op_type, + inputs=inputs, + outputs={'Out': [slice_out_var]}, + attrs=attrs) + out = slice_out_var + + if len(reverse_axes) > 0: + from .layers.tensor import reverse + out = reverse(out, axis=reverse_axes) + + # Deal with cases when all axes are decreased. + # After slice, the shape of out is [1], which should have been [], but Paddle doesn't support scalar. + # In order to ensure the correctness of the final shape of out, one dimension of out needs to be decreased. + # For example: + # # x.shape: (2,3,4) + # out = x[0, 1, 1, None] # out.shape : (1) + if len(decrease_axes) == len(var.shape): + none_axes = none_axes[1:] + + if len(none_axes) > 0: + # Deal with cases that decrease_axes is not empty + # For example: + # # x.shape: (2,3,4) + # out = x[0, 0:2, None] # out.shape : (2, 1, 4) + for idx, axis in enumerate(none_axes): + l = len([i for i in decrease_axes if i < axis]) + new_axis = axis - l + none_axes[idx] = new_axis + + # Deal with cases when all axes are decreased. + # After slice, the shape of out is [1], which should have been [], but Paddle doesn't support scalar. + # In order to ensure the correctness of the final shape of out, one dimension of out needs to be decreased. + # For example: + # # x.shape: (2,3,4) + # out = x[0, 1, 1, None] # out.shape : (1) + + from ..tensor import unsqueeze + out = unsqueeze(out, axis=none_axes) + + return out + + +def _setitem_impl_(var, item, value): + from .framework import default_main_program, Variable + + inputs = {'Input': var} + + # 1. Parse item + if not isinstance(item, tuple): + item = (item, ) + + decrease_axes = [] + axes = [] + starts = [] + ends = [] + steps = [] + + item = replace_ellipsis(var, item) + + for dim, slice_item in enumerate(item): + if is_integer_or_scalar_tensor(slice_item): + decrease_axes.append(dim) + start = slice_item + end = slice_item + 1 if slice_item != -1 else MAX_INTEGER + step = 1 + + elif isinstance(slice_item, slice): + start = slice_item.start + end = slice_item.stop + step = slice_item.step + + if start is None and end is None and step is None: + continue + + step = 1 if step is None else step + + if not isinstance(step, Variable) and step == 0: + raise ValueError( + "When assign a value to a paddle.Tensor, step can not be 0, " + "but received step is {}.".format(step)) + + if isinstance(step, Variable) and (start is None or end is None): + raise ValueError( + "When assign a value to a paddle.Tensor, it's not supported that " + "the start or end is None when the type of step is paddle.Tensor." + ) + + if start is None: + start = 0 if step > 0 else MAX_INTEGER + + if end is None: + end = MAX_INTEGER if step > 0 else (0 - MAX_INTEGER) + else: + raise IndexError( + "Valid index accept int or slice or ellipsis, but received {}.". + format(slice_item)) + + axes.append(dim) + starts.append(start) + ends.append(end) + steps.append(step) + + attrs = { + 'axes': axes, + 'starts': starts, + 'ends': ends, + 'steps': steps, + 'decrease_axes': decrease_axes + } + + from .layers import utils + if utils._contain_var(starts): + inputs['StartsTensorList'] = utils._convert_to_tensor_list(starts) + del attrs['starts'] + if utils._contain_var(ends): + inputs['EndsTensorList'] = utils._convert_to_tensor_list(ends) + del attrs['ends'] + if utils._contain_var(steps): + inputs['StepsTensorList'] = utils._convert_to_tensor_list(steps) + del attrs['steps'] + + # 2. Parse value + dtype = var.dtype + attrs['dtype'] = dtype + + from .data_feeder import convert_dtype + # 2.1 value is an integer of float + if isinstance(value, (int, float)): + value = np.array([value]).astype(convert_dtype(dtype)) + + # 2.2 value is a np.ndarray + if isinstance(value, np.ndarray): + shape = list(value.shape) + if dtype == core.VarDesc.VarType.BOOL: + value_name = "bool_values" + values = [bool(v) for v in value.flat] + elif dtype == core.VarDesc.VarType.FP32: + value_name = "fp32_values" + values = [float(v) for v in value.flat] + elif dtype == core.VarDesc.VarType.FP64: + value_name = "fp64_values" + values = [float(v) for v in value.flat] + elif dtype == core.VarDesc.VarType.INT32: + value_name = "int32_values" + values = [int(v) for v in value.flat] + elif dtype == core.VarDesc.VarType.INT64: + value_name = "int64_values" + values = [int(v) for v in value.flat] + else: + raise TypeError( + "When assign a numpy.ndarray, integer or float to a paddle.Tensor, " + "the data type of the paddle.Tensor must be bool, float32, int32 or int64, but " + "received %s." % convert_dtype(dtype)) + attrs[value_name] = values + attrs["shape"] = shape + + elif isinstance(value, Variable): + inputs["ValueTensor"] = value + else: + raise TypeError( + "Only support to assign an integer, float, numpy.ndarray or " + "paddle.Tensor to a paddle.Tensor, but received {}".format( + type(value))) + + cur_block = default_main_program().current_block() + cur_block.append_op( + type="set_value", inputs=inputs, outputs={'Out': var}, attrs=attrs) + + return var -- GitLab