From 5b367dab442f8bf8b9eba83535a25ea206e38632 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Sat, 20 Feb 2021 14:33:24 +0800 Subject: [PATCH] [static setitem] Support the index is Tensor; step>1; step<0 .(#30949) * [static setitem] support the index step > 1. tensor_a[::3] = value * [static setitem] support the index step < 0. Eg: tensor_a[::-3] = value * [static setitem] support the index is Tensor. eg: tensor_a[tensor_3:0:-1] = value * Add op version. --- paddle/fluid/operators/set_value_op.cc | 82 +++++++- paddle/fluid/operators/set_value_op.h | 141 +++++++++----- paddle/fluid/pybind/imperative.cc | 44 ++--- python/paddle/fluid/framework.py | 46 ++++- .../tests/unittests/test_set_value_op.py | 177 +++++++++++++++++- 5 files changed, 403 insertions(+), 87 deletions(-) diff --git a/paddle/fluid/operators/set_value_op.cc b/paddle/fluid/operators/set_value_op.cc index 699aa5dad5..a18238adca 100644 --- a/paddle/fluid/operators/set_value_op.cc +++ b/paddle/fluid/operators/set_value_op.cc @@ -13,8 +13,8 @@ // limitations under the License. #include "paddle/fluid/operators/set_value_op.h" - #include +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace framework { @@ -60,18 +60,52 @@ class SetValue : public framework::OperatorWithKernel { framework::proto::VarType::Type(ctx.Attr("dtype")), ctx.GetPlace()); } + + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + if (var_name == "StartsTensorList" || var_name == "EndsTensorList" || + var_name == "StepsTensorList") { + return expected_kernel_type; + } + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } }; class SetValueMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { + // Input AddInput("Input", "(Tensor) Input tensor of set_value operator."); AddInput("ValueTensor", "(Tensor) Value tensor of set_value operator.") .AsDispensable(); + AddInput("StartsTensorList", + "(vector>, optional) If provided, set_value will " + "use this. The shape of the tensor in vector must be [1]." + "It has higher priority compare with attr(starts).") + .AsDuplicable() + .AsDispensable(); + AddInput("EndsTensorList", + "(vector>, optional) If provided, set_value will " + "use this. The shape of the tensor in vector must BE [1]." + "It has higher priority compare with attr(ends).") + .AsDuplicable() + .AsDispensable(); + + AddInput("StepsTensorList", + "(vector>, optional) If provided, set_value will " + "use this. The shape of the tensor in vector must BE [1]." + "It has higher priority compare with attr(steps).") + .AsDuplicable() + .AsDispensable(); + + // Output AddOutput("Out", "(Tensor) Output tensor of set_value operator. The output is the " "same Tensor as input"); + // Attr AddAttr("dtype", "data type of input.") .InEnum( {framework::proto::VarType::BOOL, framework::proto::VarType::INT32, @@ -82,20 +116,25 @@ class SetValueMaker : public framework::OpProtoAndCheckerMaker { "axes", "(list) Axes that `starts` and `ends` apply to."); AddAttr>( "starts", - "(list) Starting indices of corresponding axis in `axes`"); + "(list) Starting indices of corresponding axis in `axes`.") + .SetDefault({}); AddAttr>( "ends", - "(list) Ending indices of corresponding axis in `axes`."); + "(list) Ending indices of corresponding axis in `axes`.") + .SetDefault({}); + AddAttr>( + "steps", "(list) Stride step from the start to the end.") + .SetDefault({}); - AddAttr>("bool_values", "store the bool values") + AddAttr>("bool_values", "Store the bool values.") .SetDefault({}); - AddAttr>("fp32_values", "store the float32 values") + AddAttr>("fp32_values", "Store the float32 values.") .SetDefault({}); - AddAttr>("int32_values", "store the int32 values") + AddAttr>("int32_values", "Store the int32 values.") .SetDefault({}); - AddAttr>("int64_values", "store the int64 values") + AddAttr>("int64_values", "Store the int64 values.") .SetDefault({}); - AddAttr>("fp64_values", "store the float64 values") + AddAttr>("fp64_values", "Store the float64 values.") .SetDefault({}); AddAttr>("shape", "(vector) Shape of values.") @@ -121,3 +160,30 @@ REGISTER_OP_CPU_KERNEL( ops::SetValueKernel, ops::SetValueKernel, ops::SetValueKernel); + +REGISTER_OP_VERSION(set_value) + .AddCheckpoint( + R"ROC( +Upgrade set_value, add 3 inputs [StartsTensorList, EndsTensorList, StepsTensorList] and 1 attribute [steps]. + )ROC", + paddle::framework::compatible::OpVersionDesc() + .NewInput("StartsTensorList", + "If provided, set_value will use this.The shape of the " + "tensor in vector must be [1]. It has higher priority " + "compare with attr(starts).") + .NewInput("EndsTensorList", + "If provided, set_value will use this.The shape of the " + "tensor in vector must be [1]. It has higher priority " + "compare with attr(ends).") + .NewInput("StepsTensorList", + "If provided, set_value will use this.The shape of the " + "tensor in vector must be [1]. It has higher priority " + "compare with attr(steps).") + .ModifyAttr("starts", + "Starting indices of corresponding axis in `axes`.", + std::vector{}) + .ModifyAttr("ends", + "Ending indices of corresponding axis in `axes`.", + std::vector{}) + .NewAttr("steps", "Stride step from the start to the end.", + std::vector{})); diff --git a/paddle/fluid/operators/set_value_op.h b/paddle/fluid/operators/set_value_op.h index 558a8276ce..6347bcd247 100644 --- a/paddle/fluid/operators/set_value_op.h +++ b/paddle/fluid/operators/set_value_op.h @@ -23,6 +23,7 @@ #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/operators/assign_value_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" +#include "paddle/fluid/operators/utils.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { @@ -58,26 +59,70 @@ inline std::string GetValueName(framework::proto::VarType::Type data_type) { return value_name; } +inline void CheckAndUpdateSlice(const framework::DDim in_dims, + const std::vector axes, + std::vector* starts, + std::vector* ends, + std::vector* steps) { + for (size_t i = 0; i < axes.size(); ++i) { + int64_t axis = axes[i]; + int64_t dim_value = in_dims[axis]; + + int64_t start = + (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i]; + int64_t end = (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i]; + start = std::max(start, static_cast(0)); + end = std::min(end, dim_value); + + int64_t step = (*steps)[i]; + PADDLE_ENFORCE_NE( + step, 0, platform::errors::InvalidArgument( + "Step should not be 0, but received step = %d.", step)); + if (step > 0) { + start = std::min(start, dim_value); + end = std::max(end, static_cast(0)); + PADDLE_ENFORCE_GT( + end, start, + platform::errors::InvalidArgument( + "When step > 0, end should be greater than start, but " + "received end = %d, start = %d.", + end, start)); + } else { + // NOTE(liym27): When step < 0, start should less and equal to dim_value-1 + // "end is -1" means contain the 0-th element of this axis. + start = std::min(start, dim_value - 1); + end = std::max(end, static_cast(-1)); + PADDLE_ENFORCE_GT( + start, end, + platform::errors::InvalidArgument( + "When step < 0, start should be greater than end, but " + "received start = %d, end = %d.", + start, end)); + } + + (*starts)[i] = start; + (*ends)[i] = end; + } +} + inline framework::DDim GetSliceDims(const framework::DDim in_dims, const std::vector axes, const std::vector starts, - const std::vector ends) { + const std::vector ends, + const std::vector steps) { framework::DDim slice_dims(in_dims); for (size_t i = 0; i < axes.size(); ++i) { int64_t axis = axes[i]; - int64_t dim_value = in_dims[axis]; + int64_t start = starts[i]; + int64_t end = ends[i]; + int64_t step = steps[i]; - int64_t start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i]; - int64_t end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i]; - start = std::max(start, static_cast(0)); - end = std::min(end, dim_value); - - PADDLE_ENFORCE_GT(end, start, platform::errors::InvalidArgument( - "end should greater than start, but " - "received end = %d, start = %d", - end, start)); - slice_dims[axis] = end - start; + if (step > 0) { + slice_dims[axis] = (end - start + step - 1) / step; + } else { + slice_dims[axis] = (end - start + step + 1) / step; + } } return slice_dims; } @@ -120,19 +165,36 @@ class SetValueKernel : public framework::OpKernel { template void SetValueCompute(const framework::ExecutionContext& ctx) const { auto* in = ctx.Input("Input"); + auto* value_tensor = ctx.Input("ValueTensor"); auto* out = ctx.Output("Out"); + auto starts_tensor_list = + ctx.MultiInput("StartsTensorList"); + auto ends_tensor_list = ctx.MultiInput("EndsTensorList"); + auto steps_tensor_list = + ctx.MultiInput("StepsTensorList"); + auto dtype = static_cast(ctx.Attr("dtype")); auto axes = ctx.Attr>("axes"); auto starts = ctx.Attr>("starts"); auto ends = ctx.Attr>("ends"); + auto steps = ctx.Attr>("steps"); auto shape = ctx.Attr>("shape"); - auto* value_tensor = ctx.Input("ValueTensor"); + + if (!starts_tensor_list.empty()) { + starts = GetDataFromTensorList(starts_tensor_list); + } + if (!ends_tensor_list.empty()) { + ends = GetDataFromTensorList(ends_tensor_list); + } + if (!steps_tensor_list.empty()) { + steps = GetDataFromTensorList(steps_tensor_list); + } auto in_dims = in->dims(); - auto value_dims = framework::make_ddim(shape); - auto slice_dims = GetSliceDims(in_dims, axes, starts, ends); + CheckAndUpdateSlice(in_dims, axes, &starts, &ends, &steps); + auto slice_dims = GetSliceDims(in_dims, axes, starts, ends, steps); auto place = ctx.GetPlace(); auto& eigen_place = @@ -160,46 +222,37 @@ class SetValueKernel : public framework::OpKernel { auto slice_e = framework::EigenTensor::From(slice_t, slice_dims); // Step 1: Set the value of out at `_index` to zero - // - Step 1.1 Get a slice tensor from out - Eigen::array offsets, extents; - Eigen::array, D> paddings; + slice_e.device(eigen_place) = slice_e.constant(T(0)); + + auto starts_indices = Eigen::DSizes(); + auto ends_indices = Eigen::DSizes(); + auto strides_indices = Eigen::DSizes(); for (size_t i = 0; i < D; ++i) { - offsets[i] = 0; - extents[i] = slice_dims[i]; - } - int64_t start; - for (size_t i = 0; i < axes.size(); ++i) { - start = starts[i] < 0 ? (starts[i] + in_dims[axes[i]]) : starts[i]; - start = std::max(start, static_cast(0)); - offsets[axes[i]] = start; + starts_indices[i] = 0; + ends_indices[i] = slice_dims[i]; + strides_indices[i] = 1; } - for (size_t i = 0; i < paddings.size(); ++i) { - paddings[i].first = offsets[i]; - paddings[i].second = (in_dims[i] - slice_dims[i]) - offsets[i]; + for (size_t i = 0; i < axes.size(); i++) { + int axis_index = axes[i]; + starts_indices[axis_index] = starts[i]; + ends_indices[axis_index] = ends[i]; + strides_indices[axis_index] = steps[i]; } - slice_e.device(eigen_place) = out_e.slice(offsets, extents); - - // - Step 1.2 Get paded tensor by padding 0 to slice tensor - pad_e.device(eigen_place) = slice_e.pad(paddings, T(0)); - - // - Step 1.3 Set 0 at `_index` of out tensor - out_e.device(eigen_place) = out_e - pad_e; + out_e.stridedSlice(starts_indices, ends_indices, strides_indices) + .device(eigen_place) = slice_e; // Step 2: Set a tensor with the same shape as out tensor. And its data at // '_index' is the same as value_tensor, and data out of '_index' to zero - - // - Step 2.1 Set the data of slice tensor to 0 - slice_e.device(eigen_place) = slice_e.constant(T(0)); - - // - Step 2.2 Set slice tensor with value + // - Step 2.1 Set slice tensor with value if (value_tensor != nullptr) { // ElementwiseComputeEx can do broadcasting ElementwiseComputeEx, DeviceContext, T>( ctx, &slice_t, value_tensor, -1, SubFunctor(), &slice_t); } else { Tensor value_t(dtype); + auto value_dims = framework::make_ddim(shape); value_t.mutable_data(value_dims, place); auto value_name = GetValueName(dtype); CopyVecotorToTensor(value_name.c_str(), &value_t, ctx); @@ -208,8 +261,10 @@ class SetValueKernel : public framework::OpKernel { ctx, &slice_t, &value_t, -1, SubFunctor(), &slice_t); } - // - Step 2.3 Pad slice tensor with 0 - pad_e.device(eigen_place) = slice_e.pad(paddings, T(0)); + // - Step 2.2 Pad slice tensor with 0 + pad_e.device(eigen_place) = pad_e.constant(T(0)); + pad_e.stridedSlice(starts_indices, ends_indices, strides_indices) + .device(eigen_place) = slice_e; // Step 3: Set out tensor with value_tensor out_e.device(eigen_place) = out_e - pad_e; diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 6d20c86757..8e894fc07a 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -587,8 +587,16 @@ void BindImperative(py::module *m_ptr) { ? PyTuple_Pack(1, _index.ptr()) : _index.ptr(); // 1. Check argumnets - // 1.1 Check whether _index can be parsed. + // 1.1 Check whether value obj is a tensor. + bool value_is_tensor = true; bool parse_index = true; + if (py::isinstance(value_obj) || + py::isinstance(value_obj) || + py::isinstance(value_obj)) { + value_is_tensor = false; + } + + // 1.2 Check whether _index can be parsed. const int size = PyTuple_GET_SIZE(index_ptr); for (int dim = 0; dim < size; ++dim) { PyObject *slice_item = PyTuple_GetItem(index_ptr, dim); @@ -598,34 +606,20 @@ void BindImperative(py::module *m_ptr) { } } - // 1.2 Check whether stride is 1. - std::vector axes, starts, ends, strides, decrease_axis, - infer_flags; - - bool stride_is_1 = true; - if (parse_index) { - ParseIndexingSlice(self_tensor, index_ptr, &axes, &starts, &ends, - &strides, &decrease_axis, &infer_flags); - stride_is_1 = - std::all_of(strides.cbegin(), strides.cend(), - [](int64_t stride) { return stride == 1; }); - } - - // 1.3 Check whether value obj is a tensor. - bool value_is_tensor = true; - if (py::isinstance(value_obj) || - py::isinstance(value_obj) || - py::isinstance(value_obj)) { - value_is_tensor = false; - } - // 2. Call op set_value to speed up if the condition is met, // otherwise call TensorToPyArray. // TODO(liym27): Try not to call TensorToPyArray because it always // copys data to cpu place, which reduces performance. - if (parse_index && stride_is_1 && value_is_tensor) { - framework::AttributeMap attrs = { - {"axes", axes}, {"starts", starts}, {"ends", ends}}; + if (parse_index && value_is_tensor) { + std::vector axes, starts, ends, steps, decrease_axis, + infer_flags; + ParseIndexingSlice(self_tensor, index_ptr, &axes, &starts, &ends, + &steps, &decrease_axis, &infer_flags); + + framework::AttributeMap attrs = {{"axes", axes}, + {"starts", starts}, + {"ends", ends}, + {"steps", steps}}; imperative::NameVarBaseMap ins = {{"Input", {self}}}; imperative::NameVarBaseMap outs = {{"Out", {self}}}; diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 8ed5add554..fd8a39259d 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1866,6 +1866,8 @@ class Variable(object): axes = [] starts = [] ends = [] + steps = [] + max_integer = sys.maxsize def replace_ellipsis(item): @@ -1877,7 +1879,12 @@ class Variable(object): # var[0, ..., 1:2] -> var[0, :, :, 1:2] item = list(item) - ell_count = item.count(Ellipsis) + + # Remove Variable to skip bug when counting Ellipsis + item_remove_var = [ + ele for ele in item if not isinstance(ele, Variable) + ] + ell_count = item_remove_var.count(Ellipsis) if ell_count == 0: return item elif ell_count > 1: @@ -1905,23 +1912,47 @@ class Variable(object): if start is None and end is None and step is None: continue - start = 0 if start is None else start step = 1 if step is None else step - # TODO: support cases when step != 1 - if step != 1: + # TODO: support cases when step < 1 + if not isinstance(step, Variable) and step == 0: raise ValueError( - "When assign a value to a paddle.Tensor, only support step is 1, " + "When assign a value to a paddle.Tensor, step can not be 0, " "but received step is {}.".format(step)) - end = max_integer if end is None else end + + if isinstance(step, Variable) and (start is None or + end is None): + raise ValueError( + "When assign a value to a paddle.Tensor, it's not supported that " + "the start or end is None when the type of step is paddle.Tensor." + ) + + if start is None: + start = 0 if step > 0 else max_integer + + if end is None: + end = max_integer if step > 0 else (0 - max_integer) else: start = slice_item end = slice_item + 1 if slice_item != -1 else max_integer + step = 1 axes.append(dim) starts.append(start) ends.append(end) + steps.append(step) - attrs = {'axes': axes, 'starts': starts, 'ends': ends} + attrs = {'axes': axes, 'starts': starts, 'ends': ends, 'steps': steps} + + from .layers import utils + if utils._contain_var(starts): + inputs['StartsTensorList'] = utils._convert_to_tensor_list(starts) + del attrs['starts'] + if utils._contain_var(ends): + inputs['EndsTensorList'] = utils._convert_to_tensor_list(ends) + del attrs['ends'] + if utils._contain_var(steps): + inputs['StepsTensorList'] = utils._convert_to_tensor_list(steps) + del attrs['steps'] # 2. Parse value dtype = self.dtype @@ -1968,6 +1999,7 @@ class Variable(object): self.block.append_op( type="set_value", inputs=inputs, outputs={'Out': self}, attrs=attrs) + return self diff --git a/python/paddle/fluid/tests/unittests/test_set_value_op.py b/python/paddle/fluid/tests/unittests/test_set_value_op.py index 79b270f162..23dac41f64 100644 --- a/python/paddle/fluid/tests/unittests/test_set_value_op.py +++ b/python/paddle/fluid/tests/unittests/test_set_value_op.py @@ -27,10 +27,13 @@ class TestSetValueBase(unittest.TestCase): paddle.enable_static() self.set_dtype() self.set_value() - self.shape = [2, 3, 4] + self.set_shape() self.data = np.ones(self.shape).astype(self.dtype) self.program = paddle.static.Program() + def set_shape(self): + self.shape = [2, 3, 4] + def set_value(self): self.value = 6 @@ -59,7 +62,8 @@ class TestSetValueApi(TestSetValueBase): self.data, out)) -# 1. Test different type of item: int, python slice, Ellipsis +# 1. Test different type of item: int, Python slice, Paddle Tensor +# 1.1 item is int class TestSetValueItemInt(TestSetValueApi): def _call_setitem(self, x): x[0] = self.value @@ -68,6 +72,8 @@ class TestSetValueItemInt(TestSetValueApi): self.data[0] = self.value +# 1.2 item is slice +# 1.2.1 step is 1 class TestSetValueItemSlice(TestSetValueApi): def _call_setitem(self, x): x[0:2] = self.value @@ -100,6 +106,102 @@ class TestSetValueItemSlice4(TestSetValueApi): self.data[0:, 1:2, :] = self.value +# 1.2.2 step > 1 +class TestSetValueItemSliceStep(TestSetValueApi): + def set_shape(self): + self.shape = [5, 5, 5] + + def _call_setitem(self, x): + x[0:2:2] = self.value + + def _get_answer(self): + self.data[0:2:2] = self.value + + +class TestSetValueItemSliceStep2(TestSetValueApi): + def set_shape(self): + self.shape = [7, 5, 5] + + def _call_setitem(self, x): + x[0:-1:3] = self.value + + def _get_answer(self): + self.data[0:-1:3] = self.value + + +class TestSetValueItemSliceStep3(TestSetValueApi): + def _call_setitem(self, x): + x[0:-1, 0:2, ::2] = self.value + + def _get_answer(self): + self.data[0:-1, 0:2, ::2] = self.value + + +class TestSetValueItemSliceStep4(TestSetValueApi): + def _call_setitem(self, x): + x[0:, 1:2:2, :] = self.value + + def _get_answer(self): + self.data[0:, 1:2:2, :] = self.value + + +# 1.2.3 step < 0 +class TestSetValueItemSliceNegetiveStep(TestSetValueApi): + def set_shape(self): + self.shape = [5, 2] + + def set_value(self): + self.value = np.array([3, 4]) + + def _call_setitem(self, x): + x[5:2:-1] = self.value + + def _get_answer(self): + self.data[5:2:-1] = self.value + + +class TestSetValueItemSliceNegetiveStep2(TestSetValueApi): + def set_shape(self): + self.shape = [5] + + def set_value(self): + self.value = np.array([3, 4]) + + def _call_setitem(self, x): + x[1::-1] = self.value + + def _get_answer(self): + self.data[1::-1] = self.value + + +class TestSetValueItemSliceNegetiveStep3(TestSetValueApi): + def set_shape(self): + self.shape = [3] + + def set_value(self): + self.value = np.array([3, 4, 5]) + + def _call_setitem(self, x): + x[::-1] = self.value + + def _get_answer(self): + self.data[::-1] = self.value + + +class TestSetValueItemSliceNegetiveStep4(TestSetValueApi): + def set_shape(self): + self.shape = [3, 4, 5] + + def _call_setitem(self, x): + x[2:0:-1, 0:2, ::-1] = self.value + + def _get_answer(self): + self.data[2:0:-1, 0:2, ::-1] = self.value + + +# 1.3 item is Ellipsis + + class TestSetValueItemEllipsis1(TestSetValueApi): def _call_setitem(self, x): x[0:, ..., 1:] = self.value @@ -132,6 +234,69 @@ class TestSetValueItemEllipsis4(TestSetValueApi): self.data[...] = self.value +# 1.4 item is Paddle Tensor +class TestSetValueItemTensor(TestSetValueApi): + def _call_setitem(self, x): + zero = paddle.full([1], 0, dtype="int32") + x[zero] = self.value + + def _get_answer(self): + self.data[0] = self.value + + +class TestSetValueItemTensor2(TestSetValueApi): + def _call_setitem(self, x): + zero = paddle.full([1], 0, dtype="int32") + two = paddle.full([1], 2, dtype="int64") + x[zero:two] = self.value + + def _get_answer(self): + self.data[0:2] = self.value + + +class TestSetValueItemTensor3(TestSetValueApi): + def _call_setitem(self, x): + zero = paddle.full([1], 0, dtype="int32") + two = paddle.full([1], 2, dtype="int64") + x[zero:-1, 0:two] = self.value + + def _get_answer(self): + self.data[0:-1, 0:2] = self.value + + +class TestSetValueItemTensor4(TestSetValueApi): + def _call_setitem(self, x): + zero = paddle.full([1], 0, dtype="int32") + two = paddle.full([1], 2, dtype="int64") + x[0:-1, zero:2, 0:6:two] = self.value + + def _get_answer(self): + self.data[0:-1, 0:2, ::2] = self.value + + +class TestSetValueItemTensor5(TestSetValueApi): + def _call_setitem(self, x): + zero = paddle.full([1], 0, dtype="int32") + two = paddle.full([1], 2, dtype="int64") + x[zero:, 1:2:two, :] = self.value + + def _get_answer(self): + self.data[0:, 1:2:2, :] = self.value + + +class TestSetValueItemTensor6(TestSetValueApi): + def set_shape(self): + self.shape = [3, 4, 5] + + def _call_setitem(self, x): + minus1 = paddle.full([1], -1, dtype="int32") + zero = paddle.full([1], 0, dtype="int32") + x[2:zero:minus1, 0:2, 10:-6:minus1] = self.value + + def _get_answer(self): + self.data[2:0:-1, 0:2, ::-1] = self.value + + # 2. Test different type of value: int, float, numpy.ndarray, Tensor # 2.1 value is int32, int64, float32, float64, bool @@ -526,15 +691,19 @@ class TestError(TestSetValueBase): y[0] = 1 def _step_error(self): - with self.assertRaisesRegexp(ValueError, "only support step is 1"): + with self.assertRaisesRegexp(ValueError, "step can not be 0"): x = paddle.ones(shape=self.shape, dtype=self.dtype) - x[0:1:2] = self.value + x[0:1:0] = self.value def _ellipsis_error(self): with self.assertRaisesRegexp( IndexError, "An index can only have a single ellipsis"): x = paddle.ones(shape=self.shape, dtype=self.dtype) x[..., ...] = self.value + with self.assertRaisesRegexp(ValueError, "the start or end is None"): + x = paddle.ones(shape=self.shape, dtype=self.dtype) + one = paddle.ones([1]) + x[::one] = self.value def _broadcast_mismatch(self): program = paddle.static.Program() -- GitLab