未验证 提交 2b44ae5d 编写于 作者: L liym27 提交者: GitHub

[cherry-pick] Polish code for setitem/getitem and support index for...

[cherry-pick] Polish code for setitem/getitem and support index for list/Tensor/None/Ellipsis/bool (#33528)

* [cherry-pick 2.1] Polish code for _getitem_impl_ (#32868)

* [cherry-pick] Polish code for setitem and getitem (#32911)

* [slice getitem] Support getitem idx is Tensor or List (#33000)

* [getitem] Support index is None for getitem in static mode (#33001)

* [Static getitem] Support static Variable getitem for Ellipsis index (#32876)

* [static getitem]Support index is list bool for getitem in static mode (#33298)
上级 bbedca46
...@@ -39,6 +39,7 @@ from . import unique_name ...@@ -39,6 +39,7 @@ from . import unique_name
import paddle.version as fluid_version import paddle.version as fluid_version
import warnings import warnings
import functools import functools
from .variable_index import _getitem_impl_, _setitem_impl_
__all__ = [ __all__ = [
'Program', 'Program',
...@@ -794,205 +795,6 @@ class ParameterMetaClass(VariableMetaClass): ...@@ -794,205 +795,6 @@ class ParameterMetaClass(VariableMetaClass):
return issubclass(t, Parameter) return issubclass(t, Parameter)
def _getitem_impl_(var, item):
"""
Slice the variable.
Args:
item(int/slice/tuple) : the index.
Returns:
Sliced variable
"""
if not isinstance(item, tuple):
item = [item]
decrease_axis = []
slice_axis = []
slice_start = []
slice_end = []
slice_step = []
use_strided_slice = False
reverse_axis = []
target_block = default_main_program().current_block()
def fill_constant(shape, value, force_cpu=False, out=None):
var.block.append_op(
type='fill_constant',
inputs={},
outputs={'Out': [out]},
attrs={
'shape': shape,
'dtype': out.dtype,
'value': float(value),
'force_cpu': force_cpu
})
out.stop_gradient = True
return out
for dim, slice_item in enumerate(item):
if isinstance(slice_item, slice):
start = slice_item.start
end = slice_item.stop
step = slice_item.step
if start is None and end is None and step is None:
continue
if step is None:
step = 1
if start is None and end is None:
assert (step == -1)
reverse_axis.append(dim)
continue
if start is None:
start = 0
if end is None:
end = 10000000
if step != 1:
use_strided_slice = True
slice_axis.append(dim)
slice_start.append(start)
slice_end.append(end)
slice_step.append(step)
else:
decrease_axis.append(dim)
slice_axis.append(dim)
slice_start.append(slice_item)
slice_step.append(1)
if isinstance(slice_item, Variable):
temp_1 = var.block.create_var(dtype=slice_item.dtype)
fill_constant([1], 1, force_cpu=True, out=temp_1)
temp_end = target_block.create_var(dtype=slice_item.dtype)
target_block.append_op(
type='elementwise_add',
inputs={'X': slice_item,
'Y': temp_1},
outputs={'Out': temp_end},
attrs={'axis': -1})
slice_end.append(temp_end)
else:
slice_end.append(slice_item + 1
if slice_item != -1 else 10000000)
def contain_var(one_list):
for ele in one_list:
if isinstance(ele, Variable):
return True
return False
def get_new_list_tensor(old_list):
new_list_tensor = []
for dim in old_list:
if isinstance(dim, Variable):
dim.stop_gradient = True
new_list_tensor.append(dim)
else:
assert (isinstance(dim, int))
temp_out = var.block.create_var(dtype='int64')
fill_constant([1], dim, force_cpu=True, out=temp_out)
new_list_tensor.append(temp_out)
return new_list_tensor
inputs = {'Input': [var]}
attrs = {
'axes': slice_axis,
'starts': [],
'ends': [],
'decrease_axis': decrease_axis
}
if (use_strided_slice == True):
attrs['strides'] = []
infer_flags = list(1 for i in range(len(slice_axis)))
# starts
if contain_var(slice_start):
inputs['StartsTensorList'] = get_new_list_tensor(slice_start)
for i, dim in enumerate(slice_start):
if isinstance(dim, Variable):
attrs['starts'].append(-1)
infer_flags[i] = -1
else:
attrs['starts'].append(dim)
else:
attrs['starts'] = slice_start
# ends
if contain_var(slice_end):
inputs['EndsTensorList'] = get_new_list_tensor(slice_end)
for i, dim in enumerate(slice_end):
if isinstance(dim, Variable):
attrs['ends'].append(-1)
infer_flags[i] = -1
else:
attrs['ends'].append(dim)
else:
attrs['ends'] = slice_end
# strides
if use_strided_slice == True:
if contain_var(slice_step):
inputs['StridesTensorList'] = get_new_list_tensor(slice_step)
for i, dim in enumerate(slice_step):
if isinstance(dim, Variable):
attrs['strides'].append(-1)
infer_flags[i] = -1
else:
attrs['strides'].append(dim)
else:
attrs['strides'] = slice_step
# infer_flags
attrs['infer_flags'] = infer_flags
out = var
if use_strided_slice == False and len(slice_axis) > 0:
# append slice_op here
slice_out_var = target_block.create_var(
name=unique_name.generate_with_ignorable_key(var.name + "_slice"),
dtype=var.dtype)
target_block.append_op(
type="slice",
inputs=inputs,
outputs={'Out': [slice_out_var]},
attrs=attrs)
out = slice_out_var
elif use_strided_slice == True and len(slice_axis) > 0:
strided_slice_out_var = target_block.create_var(
name=unique_name.generate_with_ignorable_key(var.name +
"_strided_slice"),
dtype=var.dtype)
target_block.append_op(
type="strided_slice",
inputs=inputs,
outputs={'Out': [strided_slice_out_var]},
attrs=attrs)
out = strided_slice_out_var
if len(reverse_axis) > 0:
reverse_out_var = target_block.create_var(
name=unique_name.generate_with_ignorable_key(var.name +
"_slice_reverse"),
dtype=var.dtype)
target_block.append_op(
type="reverse",
inputs={'X': out},
outputs={'Out': [reverse_out_var]},
attrs={'axis': reverse_axis})
out = reverse_out_var
return out
@six.add_metaclass(VariableMetaClass) @six.add_metaclass(VariableMetaClass)
class Variable(object): class Variable(object):
""" """
...@@ -1848,160 +1650,7 @@ class Variable(object): ...@@ -1848,160 +1650,7 @@ class Variable(object):
return _getitem_impl_(self, item) return _getitem_impl_(self, item)
def __setitem__(self, item, value): def __setitem__(self, item, value):
inputs = {'Input': self} return _setitem_impl_(self, item, value)
# 1. Parse item
if not isinstance(item, tuple):
item = [item]
decrease_axes = []
axes = []
starts = []
ends = []
steps = []
max_integer = sys.maxsize
def replace_ellipsis(item):
# Use slice(None) to replace Ellipsis.
# For var, var.shape = [3,4,5,6]
#
# var[..., 1:2] -> var[:, :, :, 1:2]
# var[0, ...] -> var[0]
# var[0, ..., 1:2] -> var[0, :, :, 1:2]
item = list(item)
# Remove Variable to skip bug when counting Ellipsis
item_remove_var = [
ele for ele in item if not isinstance(ele, Variable)
]
ell_count = item_remove_var.count(Ellipsis)
if ell_count == 0:
return item
elif ell_count > 1:
raise IndexError(
"An index can only have a single ellipsis ('...')")
ell_idx = item.index(Ellipsis)
if ell_idx == len(item) - 1:
return item[:-1]
else:
item[ell_idx:ell_idx + 1] = [slice(None)] * (
len(self.shape) - len(item) + 1)
return item
item = replace_ellipsis(item)
for dim, slice_item in enumerate(item):
if isinstance(slice_item, slice):
start = slice_item.start
end = slice_item.stop
step = slice_item.step
if start is None and end is None and step is None:
continue
step = 1 if step is None else step
# TODO: support cases when step < 1
if not isinstance(step, Variable) and step == 0:
raise ValueError(
"When assign a value to a paddle.Tensor, step can not be 0, "
"but received step is {}.".format(step))
if isinstance(step, Variable) and (start is None or
end is None):
raise ValueError(
"When assign a value to a paddle.Tensor, it's not supported that "
"the start or end is None when the type of step is paddle.Tensor."
)
if start is None:
start = 0 if step > 0 else max_integer
if end is None:
end = max_integer if step > 0 else (0 - max_integer)
else:
decrease_axes.append(dim)
start = slice_item
end = slice_item + 1 if slice_item != -1 else max_integer
step = 1
axes.append(dim)
starts.append(start)
ends.append(end)
steps.append(step)
attrs = {
'axes': axes,
'starts': starts,
'ends': ends,
'steps': steps,
'decrease_axes': decrease_axes
}
from .layers import utils
if utils._contain_var(starts):
inputs['StartsTensorList'] = utils._convert_to_tensor_list(starts)
del attrs['starts']
if utils._contain_var(ends):
inputs['EndsTensorList'] = utils._convert_to_tensor_list(ends)
del attrs['ends']
if utils._contain_var(steps):
inputs['StepsTensorList'] = utils._convert_to_tensor_list(steps)
del attrs['steps']
# 2. Parse value
dtype = self.dtype
attrs['dtype'] = dtype
from .data_feeder import convert_dtype
# 2.1 value is an integer of float
if isinstance(value, (int, float)):
value = np.array([value]).astype(convert_dtype(dtype))
# 2.2 value is a np.ndarray
if isinstance(value, np.ndarray):
shape = list(value.shape)
if dtype == core.VarDesc.VarType.BOOL:
value_name = "bool_values"
values = [bool(v) for v in value.flat]
elif dtype == core.VarDesc.VarType.FP32:
value_name = "fp32_values"
values = [float(v) for v in value.flat]
elif dtype == core.VarDesc.VarType.FP64:
value_name = "fp64_values"
values = [float(v) for v in value.flat]
elif dtype == core.VarDesc.VarType.INT32:
value_name = "int32_values"
values = [int(v) for v in value.flat]
elif dtype == core.VarDesc.VarType.INT64:
value_name = "int64_values"
values = [int(v) for v in value.flat]
else:
raise TypeError(
"When assign a numpy.ndarray, integer or float to a paddle.Tensor, "
"the data type of the paddle.Tensor must be bool, float32, int32 or int64, but "
"received %s." % convert_dtype(dtype))
attrs[value_name] = values
attrs["shape"] = shape
elif isinstance(value, Variable):
inputs["ValueTensor"] = value
else:
raise TypeError(
"Only support to assign an integer, float, numpy.ndarray or "
"paddle.Tensor to a paddle.Tensor, but received {}".format(
type(value)))
cur_block = default_main_program().current_block()
cur_block.append_op(
type="set_value", inputs=inputs, outputs={'Out': self}, attrs=attrs)
return self
def get_value(self, scope=None): def get_value(self, scope=None):
""" """
......
...@@ -15,12 +15,16 @@ ...@@ -15,12 +15,16 @@
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import paddle
from paddle.fluid.framework import default_main_program, Program, convert_np_dtype_to_dtype_, in_dygraph_mode from paddle.fluid.framework import default_main_program, Program, convert_np_dtype_to_dtype_, in_dygraph_mode
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
import paddle.fluid.core as core import paddle.fluid.core as core
import numpy as np import numpy as np
paddle.enable_static()
class TestVariable(unittest.TestCase): class TestVariable(unittest.TestCase):
def test_np_dtype_convert(self): def test_np_dtype_convert(self):
...@@ -161,12 +165,125 @@ class TestVariable(unittest.TestCase): ...@@ -161,12 +165,125 @@ class TestVariable(unittest.TestCase):
self.assertTrue( self.assertTrue(
np.array_equal(local_out[15], tensor_array[::-1, ::-1, ::-1])) np.array_equal(local_out[15], tensor_array[::-1, ::-1, ::-1]))
def test_slice(self): def _test_slice_index_tensor(self, place):
place = fluid.CPUPlace() data = np.random.rand(2, 3).astype("float32")
self._test_slice(place) prog = paddle.static.Program()
with paddle.static.program_guard(prog):
x = paddle.assign(data)
idx0 = [1, 0]
idx1 = [0, 1]
idx2 = [0, 0]
idx3 = [1, 1]
out0 = x[paddle.assign(np.array(idx0))]
out1 = x[paddle.assign(np.array(idx1))]
out2 = x[paddle.assign(np.array(idx2))]
out3 = x[paddle.assign(np.array(idx3))]
exe = paddle.static.Executor(place)
result = exe.run(prog, fetch_list=[out0, out1, out2, out3])
expected = [data[idx0], data[idx1], data[idx2], data[idx3]]
self.assertTrue((result[0] == expected[0]).all())
self.assertTrue((result[1] == expected[1]).all())
self.assertTrue((result[2] == expected[2]).all())
self.assertTrue((result[3] == expected[3]).all())
with self.assertRaises(IndexError):
one = paddle.ones(shape=[1])
res = x[one, [0, 0]]
def _test_slice_index_list(self, place):
data = np.random.rand(2, 3).astype("float32")
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
x = paddle.assign(data)
idx0 = [1, 0]
idx1 = [0, 1]
idx2 = [0, 0]
idx3 = [1, 1]
out0 = x[idx0]
out1 = x[idx1]
out2 = x[idx2]
out3 = x[idx3]
exe = paddle.static.Executor(place)
result = exe.run(prog, fetch_list=[out0, out1, out2, out3])
expected = [data[idx0], data[idx1], data[idx2], data[idx3]]
self.assertTrue((result[0] == expected[0]).all())
self.assertTrue((result[1] == expected[1]).all())
self.assertTrue((result[2] == expected[2]).all())
self.assertTrue((result[3] == expected[3]).all())
def _test_slice_index_ellipsis(self, place):
data = np.random.rand(2, 3, 4).astype("float32")
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
x = paddle.assign(data)
out1 = x[0:, ..., 1:]
out2 = x[0:, ...]
out3 = x[..., 1:]
out4 = x[...]
exe = paddle.static.Executor(place)
result = exe.run(prog, fetch_list=[out1, out2, out3, out4])
expected = [data[0:, ..., 1:], data[0:, ...], data[..., 1:], data[...]]
self.assertTrue((result[0] == expected[0]).all())
self.assertTrue((result[1] == expected[1]).all())
self.assertTrue((result[2] == expected[2]).all())
self.assertTrue((result[3] == expected[3]).all())
with self.assertRaises(IndexError):
res = x[[1, 0], [0, 0]]
with self.assertRaises(TypeError):
res = x[[1.2, 0]]
def _test_slice_index_list_bool(self, place):
data = np.random.rand(2, 3).astype("float32")
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
x = paddle.assign(data)
idx0 = [True, False]
idx1 = [False, True]
idx2 = [False, False]
idx3 = [True, True]
out0 = x[idx0]
out1 = x[idx1]
out2 = x[idx2]
out3 = x[idx3]
exe = paddle.static.Executor(place)
result = exe.run(prog, fetch_list=[out0, out1, out2, out3])
expected = [data[idx0], data[idx1], data[idx2], data[idx3]]
self.assertTrue((result[0] == expected[0]).all())
self.assertTrue((result[1] == expected[1]).all())
self.assertTrue((result[2] == expected[2]).all())
self.assertTrue((result[3] == expected[3]).all())
with self.assertRaises(TypeError):
res = x[[True, 0]]
def test_slice(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
self._test_slice(core.CUDAPlace(0)) places.append(core.CUDAPlace(0))
for place in places:
self._test_slice(place)
self._test_slice_index_tensor(place)
self._test_slice_index_list(place)
self._test_slice_index_ellipsis(place)
self._test_slice_index_list_bool(place)
def _tostring(self): def _tostring(self):
b = default_main_program().current_block() b = default_main_program().current_block()
...@@ -229,5 +346,61 @@ class TestVariable(unittest.TestCase): ...@@ -229,5 +346,61 @@ class TestVariable(unittest.TestCase):
self.assertRaises(Exception, _test) self.assertRaises(Exception, _test)
class TestVariableSlice(unittest.TestCase):
def _test_item_none(self, place):
data = np.random.rand(2, 3, 4).astype("float32")
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
x = paddle.assign(data)
out0 = x[0:, None, 1:]
out1 = x[0:, None]
out2 = x[None, 1:]
out3 = x[None]
outs = [out0, out1, out2, out3]
exe = paddle.static.Executor(place)
result = exe.run(prog, fetch_list=outs)
expected = [
data[0:, None, 1:], data[0:, None], data[None, 1:], data[None]
]
for i in range(len(outs)):
self.assertEqual(outs[i].shape, expected[i].shape)
self.assertTrue((result[i] == expected[i]).all())
def _test_item_none_and_decrease(self, place):
data = np.random.rand(2, 3, 4).astype("float32")
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
x = paddle.assign(data)
out0 = x[0, 1:, None]
out1 = x[0, None]
out2 = x[None, 1]
out3 = x[None]
out4 = x[0, 0, 0, None]
out5 = x[None, 0, 0, 0, None]
outs = [out0, out1, out2, out3, out4, out5]
exe = paddle.static.Executor(place)
result = exe.run(prog, fetch_list=outs)
expected = [
data[0, 1:, None], data[0, None], data[None, 1], data[None],
data[0, 0, 0, None], data[None, 0, 0, 0, None]
]
for i in range(len(outs)):
self.assertEqual(outs[i].shape, expected[i].shape)
self.assertTrue((result[i] == expected[i]).all())
def test_slice(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self._test_item_none(place)
self._test_item_none_and_decrease(place)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import numpy as np
from . import unique_name
from . import core
MAX_INTEGER = 2**31 - 1
def replace_ellipsis(var, item):
from .framework import Variable
# Use slice(None) to replace Ellipsis.
# For var, var.shape = [3,4,5,6]
#
# var[..., 1:2] -> var[:, :, :, 1:2]
# var[0, ...] -> var[0]
# var[0, ..., 1:2] -> var[0, :, :, 1:2]
item = list(item)
# Remove Variable to skip bug when counting Ellipsis
item_remove_var = [ele for ele in item if not isinstance(ele, Variable)]
ell_count = item_remove_var.count(Ellipsis)
if ell_count == 0:
return item
elif ell_count > 1:
raise IndexError("An index can only have a single ellipsis ('...')")
ell_idx = item.index(Ellipsis)
if ell_idx == len(item) - 1:
return item[:-1]
else:
item[ell_idx:ell_idx + 1] = [slice(None)] * (
len(var.shape) - len(item) + 1)
return item
def replace_none(item):
new_item = []
none_axes = []
for i, slice_item in enumerate(item):
if slice_item is None:
none_axes.append(i)
else:
new_item.append(slice_item)
return new_item, none_axes
def is_integer_or_scalar_tensor(ele):
from .framework import Variable
if isinstance(ele, int):
return True
elif isinstance(ele, Variable):
if len(ele.shape) == 1 and ele.shape[0] == 1:
return True
return False
def deal_attrs(attrs, attr, attr_name, tensor_attr_name, inputs, infer_flags):
from .framework import Variable
from .layers import utils
if utils._contain_var(attr):
inputs[tensor_attr_name] = utils._convert_to_tensor_list(
attr, dtype="int64")
for i, dim in enumerate(attr):
if isinstance(dim, Variable):
attrs[attr_name].append(-1)
infer_flags[i] = -1
else:
attrs[attr_name].append(dim)
else:
attrs[attr_name] = attr
def _getitem_impl_(var, item):
"""
Slice the variable.
Args:
item(int/slice/tuple) : the index.
Returns:
Sliced variable
"""
from .framework import default_main_program, Variable
if not isinstance(item, tuple):
item = (item, )
decrease_axes = []
axes = []
starts = []
ends = []
steps = []
reverse_axes = []
use_strided_slice = False
item, none_axes = replace_none(item)
item = replace_ellipsis(var, item)
for dim, slice_item in enumerate(item):
if is_integer_or_scalar_tensor(slice_item):
decrease_axes.append(dim)
start = slice_item
step = 1
end = slice_item + 1 if slice_item != -1 else MAX_INTEGER
elif isinstance(slice_item, slice):
start = slice_item.start
end = slice_item.stop
step = slice_item.step
if start is None and end is None and step is None:
continue
step = 1 if step is None else step
if start is None and end is None:
assert (step == -1)
reverse_axes.append(dim)
continue
start = 0 if start is None else start
end = MAX_INTEGER if end is None else end
elif isinstance(slice_item, list):
is_bool_list = False
for i in slice_item:
if not isinstance(i, (int, bool)):
raise TypeError("Only support int or bool in index list.")
if isinstance(i, bool):
is_bool_list = True
break
if len(item) != 1:
raise IndexError(
"When index contains a list, its length must be 1, but received {}".
format(len(item)))
if is_bool_list:
new_slice_item = []
for idx, ele in enumerate(slice_item):
if not isinstance(ele, bool):
raise TypeError(
"Mixed bool index with other types is not supported."
)
if ele is True:
new_slice_item.append(idx)
slice_item = new_slice_item
from .layers import assign
from ..tensor import index_select
idx = assign(np.array(slice_item).astype("int32"))
return index_select(var, index=idx, axis=0)
elif isinstance(slice_item, Variable):
if len(item) != 1:
raise IndexError(
"When index contains a Tensor, its length must be 1, but received {}".
format(len(item)))
from ..tensor import index_select
return index_select(var, index=slice_item, axis=0)
else:
raise IndexError(
"Valid index accept int or slice or ellipsis, but received {}.".
format(slice_item))
axes.append(dim)
starts.append(start)
ends.append(end)
steps.append(step)
use_strided_slice = True if step != 1 else use_strided_slice
inputs = {'Input': [var]}
attrs = {
'axes': axes,
'starts': [],
'ends': [],
'decrease_axis': decrease_axes
}
if use_strided_slice:
attrs['strides'] = []
infer_flags = [1] * len(axes)
deal_attrs(attrs, starts, "starts", "StartsTensorList", inputs, infer_flags)
deal_attrs(attrs, ends, "ends", "EndsTensorList", inputs, infer_flags)
deal_attrs(attrs, steps, "strides", "StridesTensorList", inputs,
infer_flags)
attrs['infer_flags'] = infer_flags
out = var
if len(axes) > 0:
target_block = default_main_program().current_block()
op_type = "strided_slice" if use_strided_slice else "slice"
slice_out_var = target_block.create_var(
name=unique_name.generate_with_ignorable_key(var.name + "_" +
op_type),
dtype=var.dtype)
target_block.append_op(
type=op_type,
inputs=inputs,
outputs={'Out': [slice_out_var]},
attrs=attrs)
out = slice_out_var
if len(reverse_axes) > 0:
from .layers.tensor import reverse
out = reverse(out, axis=reverse_axes)
# Deal with cases when all axes are decreased.
# After slice, the shape of out is [1], which should have been [], but Paddle doesn't support scalar.
# In order to ensure the correctness of the final shape of out, one dimension of out needs to be decreased.
# For example:
# # x.shape: (2,3,4)
# out = x[0, 1, 1, None] # out.shape : (1)
if len(decrease_axes) == len(var.shape):
none_axes = none_axes[1:]
if len(none_axes) > 0:
# Deal with cases that decrease_axes is not empty
# For example:
# # x.shape: (2,3,4)
# out = x[0, 0:2, None] # out.shape : (2, 1, 4)
for idx, axis in enumerate(none_axes):
l = len([i for i in decrease_axes if i < axis])
new_axis = axis - l
none_axes[idx] = new_axis
# Deal with cases when all axes are decreased.
# After slice, the shape of out is [1], which should have been [], but Paddle doesn't support scalar.
# In order to ensure the correctness of the final shape of out, one dimension of out needs to be decreased.
# For example:
# # x.shape: (2,3,4)
# out = x[0, 1, 1, None] # out.shape : (1)
from ..tensor import unsqueeze
out = unsqueeze(out, axis=none_axes)
return out
def _setitem_impl_(var, item, value):
from .framework import default_main_program, Variable
inputs = {'Input': var}
# 1. Parse item
if not isinstance(item, tuple):
item = (item, )
decrease_axes = []
axes = []
starts = []
ends = []
steps = []
item = replace_ellipsis(var, item)
for dim, slice_item in enumerate(item):
if is_integer_or_scalar_tensor(slice_item):
decrease_axes.append(dim)
start = slice_item
end = slice_item + 1 if slice_item != -1 else MAX_INTEGER
step = 1
elif isinstance(slice_item, slice):
start = slice_item.start
end = slice_item.stop
step = slice_item.step
if start is None and end is None and step is None:
continue
step = 1 if step is None else step
if not isinstance(step, Variable) and step == 0:
raise ValueError(
"When assign a value to a paddle.Tensor, step can not be 0, "
"but received step is {}.".format(step))
if isinstance(step, Variable) and (start is None or end is None):
raise ValueError(
"When assign a value to a paddle.Tensor, it's not supported that "
"the start or end is None when the type of step is paddle.Tensor."
)
if start is None:
start = 0 if step > 0 else MAX_INTEGER
if end is None:
end = MAX_INTEGER if step > 0 else (0 - MAX_INTEGER)
else:
raise IndexError(
"Valid index accept int or slice or ellipsis, but received {}.".
format(slice_item))
axes.append(dim)
starts.append(start)
ends.append(end)
steps.append(step)
attrs = {
'axes': axes,
'starts': starts,
'ends': ends,
'steps': steps,
'decrease_axes': decrease_axes
}
from .layers import utils
if utils._contain_var(starts):
inputs['StartsTensorList'] = utils._convert_to_tensor_list(starts)
del attrs['starts']
if utils._contain_var(ends):
inputs['EndsTensorList'] = utils._convert_to_tensor_list(ends)
del attrs['ends']
if utils._contain_var(steps):
inputs['StepsTensorList'] = utils._convert_to_tensor_list(steps)
del attrs['steps']
# 2. Parse value
dtype = var.dtype
attrs['dtype'] = dtype
from .data_feeder import convert_dtype
# 2.1 value is an integer of float
if isinstance(value, (int, float)):
value = np.array([value]).astype(convert_dtype(dtype))
# 2.2 value is a np.ndarray
if isinstance(value, np.ndarray):
shape = list(value.shape)
if dtype == core.VarDesc.VarType.BOOL:
value_name = "bool_values"
values = [bool(v) for v in value.flat]
elif dtype == core.VarDesc.VarType.FP32:
value_name = "fp32_values"
values = [float(v) for v in value.flat]
elif dtype == core.VarDesc.VarType.FP64:
value_name = "fp64_values"
values = [float(v) for v in value.flat]
elif dtype == core.VarDesc.VarType.INT32:
value_name = "int32_values"
values = [int(v) for v in value.flat]
elif dtype == core.VarDesc.VarType.INT64:
value_name = "int64_values"
values = [int(v) for v in value.flat]
else:
raise TypeError(
"When assign a numpy.ndarray, integer or float to a paddle.Tensor, "
"the data type of the paddle.Tensor must be bool, float32, int32 or int64, but "
"received %s." % convert_dtype(dtype))
attrs[value_name] = values
attrs["shape"] = shape
elif isinstance(value, Variable):
inputs["ValueTensor"] = value
else:
raise TypeError(
"Only support to assign an integer, float, numpy.ndarray or "
"paddle.Tensor to a paddle.Tensor, but received {}".format(
type(value)))
cur_block = default_main_program().current_block()
cur_block.append_op(
type="set_value", inputs=inputs, outputs={'Out': var}, attrs=attrs)
return var
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册