未验证 提交 5b367dab 编写于 作者: L liym27 提交者: GitHub

[static setitem] Support the index is Tensor; step>1; step<0 .(#30949)

* [static setitem] support the index step > 1. tensor_a[::3] = value

* [static setitem] support the index step < 0. Eg: tensor_a[::-3] = value

* [static setitem] support the index is Tensor. eg: tensor_a[tensor_3:0:-1] = value

* Add op version.
上级 eb3050fa
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/set_value_op.h" #include "paddle/fluid/operators/set_value_op.h"
#include <string> #include <string>
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -60,18 +60,52 @@ class SetValue : public framework::OperatorWithKernel { ...@@ -60,18 +60,52 @@ class SetValue : public framework::OperatorWithKernel {
framework::proto::VarType::Type(ctx.Attr<int>("dtype")), framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
ctx.GetPlace()); ctx.GetPlace());
} }
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "StartsTensorList" || var_name == "EndsTensorList" ||
var_name == "StepsTensorList") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
}; };
class SetValueMaker : public framework::OpProtoAndCheckerMaker { class SetValueMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
// Input
AddInput("Input", "(Tensor) Input tensor of set_value operator."); AddInput("Input", "(Tensor) Input tensor of set_value operator.");
AddInput("ValueTensor", "(Tensor) Value tensor of set_value operator.") AddInput("ValueTensor", "(Tensor) Value tensor of set_value operator.")
.AsDispensable(); .AsDispensable();
AddInput("StartsTensorList",
"(vector<Tensor<int32>>, optional) If provided, set_value will "
"use this. The shape of the tensor in vector must be [1]."
"It has higher priority compare with attr(starts).")
.AsDuplicable()
.AsDispensable();
AddInput("EndsTensorList",
"(vector<Tensor<int32>>, optional) If provided, set_value will "
"use this. The shape of the tensor in vector must BE [1]."
"It has higher priority compare with attr(ends).")
.AsDuplicable()
.AsDispensable();
AddInput("StepsTensorList",
"(vector<Tensor<int32>>, optional) If provided, set_value will "
"use this. The shape of the tensor in vector must BE [1]."
"It has higher priority compare with attr(steps).")
.AsDuplicable()
.AsDispensable();
// Output
AddOutput("Out", AddOutput("Out",
"(Tensor) Output tensor of set_value operator. The output is the " "(Tensor) Output tensor of set_value operator. The output is the "
"same Tensor as input"); "same Tensor as input");
// Attr
AddAttr<int>("dtype", "data type of input.") AddAttr<int>("dtype", "data type of input.")
.InEnum( .InEnum(
{framework::proto::VarType::BOOL, framework::proto::VarType::INT32, {framework::proto::VarType::BOOL, framework::proto::VarType::INT32,
...@@ -82,20 +116,25 @@ class SetValueMaker : public framework::OpProtoAndCheckerMaker { ...@@ -82,20 +116,25 @@ class SetValueMaker : public framework::OpProtoAndCheckerMaker {
"axes", "(list<int64_t>) Axes that `starts` and `ends` apply to."); "axes", "(list<int64_t>) Axes that `starts` and `ends` apply to.");
AddAttr<std::vector<int64_t>>( AddAttr<std::vector<int64_t>>(
"starts", "starts",
"(list<int64_t>) Starting indices of corresponding axis in `axes`"); "(list<int64_t>) Starting indices of corresponding axis in `axes`.")
.SetDefault({});
AddAttr<std::vector<int64_t>>( AddAttr<std::vector<int64_t>>(
"ends", "ends",
"(list<int64_t>) Ending indices of corresponding axis in `axes`."); "(list<int64_t>) Ending indices of corresponding axis in `axes`.")
.SetDefault({});
AddAttr<std::vector<int64_t>>(
"steps", "(list<int64_t>) Stride step from the start to the end.")
.SetDefault({});
AddAttr<std::vector<int>>("bool_values", "store the bool values") AddAttr<std::vector<int>>("bool_values", "Store the bool values.")
.SetDefault({}); .SetDefault({});
AddAttr<std::vector<float>>("fp32_values", "store the float32 values") AddAttr<std::vector<float>>("fp32_values", "Store the float32 values.")
.SetDefault({}); .SetDefault({});
AddAttr<std::vector<int>>("int32_values", "store the int32 values") AddAttr<std::vector<int>>("int32_values", "Store the int32 values.")
.SetDefault({}); .SetDefault({});
AddAttr<std::vector<int64_t>>("int64_values", "store the int64 values") AddAttr<std::vector<int64_t>>("int64_values", "Store the int64 values.")
.SetDefault({}); .SetDefault({});
AddAttr<std::vector<double>>("fp64_values", "store the float64 values") AddAttr<std::vector<double>>("fp64_values", "Store the float64 values.")
.SetDefault({}); .SetDefault({});
AddAttr<std::vector<int64_t>>("shape", "(vector<int64_t>) Shape of values.") AddAttr<std::vector<int64_t>>("shape", "(vector<int64_t>) Shape of values.")
...@@ -121,3 +160,30 @@ REGISTER_OP_CPU_KERNEL( ...@@ -121,3 +160,30 @@ REGISTER_OP_CPU_KERNEL(
ops::SetValueKernel<paddle::platform::CPUDeviceContext, float>, ops::SetValueKernel<paddle::platform::CPUDeviceContext, float>,
ops::SetValueKernel<paddle::platform::CPUDeviceContext, double>, ops::SetValueKernel<paddle::platform::CPUDeviceContext, double>,
ops::SetValueKernel<paddle::platform::CPUDeviceContext, bool>); ops::SetValueKernel<paddle::platform::CPUDeviceContext, bool>);
REGISTER_OP_VERSION(set_value)
.AddCheckpoint(
R"ROC(
Upgrade set_value, add 3 inputs [StartsTensorList, EndsTensorList, StepsTensorList] and 1 attribute [steps].
)ROC",
paddle::framework::compatible::OpVersionDesc()
.NewInput("StartsTensorList",
"If provided, set_value will use this.The shape of the "
"tensor in vector must be [1]. It has higher priority "
"compare with attr(starts).")
.NewInput("EndsTensorList",
"If provided, set_value will use this.The shape of the "
"tensor in vector must be [1]. It has higher priority "
"compare with attr(ends).")
.NewInput("StepsTensorList",
"If provided, set_value will use this.The shape of the "
"tensor in vector must be [1]. It has higher priority "
"compare with attr(steps).")
.ModifyAttr("starts",
"Starting indices of corresponding axis in `axes`.",
std::vector<int64_t>{})
.ModifyAttr("ends",
"Ending indices of corresponding axis in `axes`.",
std::vector<int64_t>{})
.NewAttr("steps", "Stride step from the start to the end.",
std::vector<int64_t>{}));
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/assign_value_op.h" #include "paddle/fluid/operators/assign_value_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
...@@ -58,26 +59,70 @@ inline std::string GetValueName(framework::proto::VarType::Type data_type) { ...@@ -58,26 +59,70 @@ inline std::string GetValueName(framework::proto::VarType::Type data_type) {
return value_name; return value_name;
} }
inline framework::DDim GetSliceDims(const framework::DDim in_dims, inline void CheckAndUpdateSlice(const framework::DDim in_dims,
const std::vector<int64_t> axes, const std::vector<int64_t> axes,
const std::vector<int64_t> starts, std::vector<int64_t>* starts,
const std::vector<int64_t> ends) { std::vector<int64_t>* ends,
framework::DDim slice_dims(in_dims); std::vector<int64_t>* steps) {
for (size_t i = 0; i < axes.size(); ++i) { for (size_t i = 0; i < axes.size(); ++i) {
int64_t axis = axes[i]; int64_t axis = axes[i];
int64_t dim_value = in_dims[axis]; int64_t dim_value = in_dims[axis];
int64_t start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i]; int64_t start =
int64_t end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i]; (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
int64_t end = (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i];
start = std::max(start, static_cast<int64_t>(0)); start = std::max(start, static_cast<int64_t>(0));
end = std::min(end, dim_value); end = std::min(end, dim_value);
PADDLE_ENFORCE_GT(end, start, platform::errors::InvalidArgument( int64_t step = (*steps)[i];
"end should greater than start, but " PADDLE_ENFORCE_NE(
"received end = %d, start = %d", step, 0, platform::errors::InvalidArgument(
"Step should not be 0, but received step = %d.", step));
if (step > 0) {
start = std::min(start, dim_value);
end = std::max(end, static_cast<int64_t>(0));
PADDLE_ENFORCE_GT(
end, start,
platform::errors::InvalidArgument(
"When step > 0, end should be greater than start, but "
"received end = %d, start = %d.",
end, start)); end, start));
slice_dims[axis] = end - start; } else {
// NOTE(liym27): When step < 0, start should less and equal to dim_value-1
// "end is -1" means contain the 0-th element of this axis.
start = std::min(start, dim_value - 1);
end = std::max(end, static_cast<int64_t>(-1));
PADDLE_ENFORCE_GT(
start, end,
platform::errors::InvalidArgument(
"When step < 0, start should be greater than end, but "
"received start = %d, end = %d.",
start, end));
}
(*starts)[i] = start;
(*ends)[i] = end;
}
}
inline framework::DDim GetSliceDims(const framework::DDim in_dims,
const std::vector<int64_t> axes,
const std::vector<int64_t> starts,
const std::vector<int64_t> ends,
const std::vector<int64_t> steps) {
framework::DDim slice_dims(in_dims);
for (size_t i = 0; i < axes.size(); ++i) {
int64_t axis = axes[i];
int64_t start = starts[i];
int64_t end = ends[i];
int64_t step = steps[i];
if (step > 0) {
slice_dims[axis] = (end - start + step - 1) / step;
} else {
slice_dims[axis] = (end - start + step + 1) / step;
}
} }
return slice_dims; return slice_dims;
} }
...@@ -120,19 +165,36 @@ class SetValueKernel : public framework::OpKernel<T> { ...@@ -120,19 +165,36 @@ class SetValueKernel : public framework::OpKernel<T> {
template <size_t D> template <size_t D>
void SetValueCompute(const framework::ExecutionContext& ctx) const { void SetValueCompute(const framework::ExecutionContext& ctx) const {
auto* in = ctx.Input<framework::LoDTensor>("Input"); auto* in = ctx.Input<framework::LoDTensor>("Input");
auto* value_tensor = ctx.Input<framework::LoDTensor>("ValueTensor");
auto* out = ctx.Output<framework::LoDTensor>("Out"); auto* out = ctx.Output<framework::LoDTensor>("Out");
auto starts_tensor_list =
ctx.MultiInput<framework::Tensor>("StartsTensorList");
auto ends_tensor_list = ctx.MultiInput<framework::Tensor>("EndsTensorList");
auto steps_tensor_list =
ctx.MultiInput<framework::Tensor>("StepsTensorList");
auto dtype = auto dtype =
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")); static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
auto axes = ctx.Attr<std::vector<int64_t>>("axes"); auto axes = ctx.Attr<std::vector<int64_t>>("axes");
auto starts = ctx.Attr<std::vector<int64_t>>("starts"); auto starts = ctx.Attr<std::vector<int64_t>>("starts");
auto ends = ctx.Attr<std::vector<int64_t>>("ends"); auto ends = ctx.Attr<std::vector<int64_t>>("ends");
auto steps = ctx.Attr<std::vector<int64_t>>("steps");
auto shape = ctx.Attr<std::vector<int64_t>>("shape"); auto shape = ctx.Attr<std::vector<int64_t>>("shape");
auto* value_tensor = ctx.Input<framework::LoDTensor>("ValueTensor");
if (!starts_tensor_list.empty()) {
starts = GetDataFromTensorList<int64_t>(starts_tensor_list);
}
if (!ends_tensor_list.empty()) {
ends = GetDataFromTensorList<int64_t>(ends_tensor_list);
}
if (!steps_tensor_list.empty()) {
steps = GetDataFromTensorList<int64_t>(steps_tensor_list);
}
auto in_dims = in->dims(); auto in_dims = in->dims();
auto value_dims = framework::make_ddim(shape); CheckAndUpdateSlice(in_dims, axes, &starts, &ends, &steps);
auto slice_dims = GetSliceDims(in_dims, axes, starts, ends); auto slice_dims = GetSliceDims(in_dims, axes, starts, ends, steps);
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
auto& eigen_place = auto& eigen_place =
...@@ -160,46 +222,37 @@ class SetValueKernel : public framework::OpKernel<T> { ...@@ -160,46 +222,37 @@ class SetValueKernel : public framework::OpKernel<T> {
auto slice_e = framework::EigenTensor<T, D>::From(slice_t, slice_dims); auto slice_e = framework::EigenTensor<T, D>::From(slice_t, slice_dims);
// Step 1: Set the value of out at `_index` to zero // Step 1: Set the value of out at `_index` to zero
// - Step 1.1 Get a slice tensor from out slice_e.device(eigen_place) = slice_e.constant(T(0));
Eigen::array<int64_t, D> offsets, extents;
Eigen::array<std::pair<int64_t, int64_t>, D> paddings; auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto strides_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
for (size_t i = 0; i < D; ++i) { for (size_t i = 0; i < D; ++i) {
offsets[i] = 0; starts_indices[i] = 0;
extents[i] = slice_dims[i]; ends_indices[i] = slice_dims[i];
strides_indices[i] = 1;
} }
int64_t start; for (size_t i = 0; i < axes.size(); i++) {
for (size_t i = 0; i < axes.size(); ++i) { int axis_index = axes[i];
start = starts[i] < 0 ? (starts[i] + in_dims[axes[i]]) : starts[i]; starts_indices[axis_index] = starts[i];
start = std::max(start, static_cast<int64_t>(0)); ends_indices[axis_index] = ends[i];
offsets[axes[i]] = start; strides_indices[axis_index] = steps[i];
}
for (size_t i = 0; i < paddings.size(); ++i) {
paddings[i].first = offsets[i];
paddings[i].second = (in_dims[i] - slice_dims[i]) - offsets[i];
} }
slice_e.device(eigen_place) = out_e.slice(offsets, extents); out_e.stridedSlice(starts_indices, ends_indices, strides_indices)
.device(eigen_place) = slice_e;
// - Step 1.2 Get paded tensor by padding 0 to slice tensor
pad_e.device(eigen_place) = slice_e.pad(paddings, T(0));
// - Step 1.3 Set 0 at `_index` of out tensor
out_e.device(eigen_place) = out_e - pad_e;
// Step 2: Set a tensor with the same shape as out tensor. And its data at // Step 2: Set a tensor with the same shape as out tensor. And its data at
// '_index' is the same as value_tensor, and data out of '_index' to zero // '_index' is the same as value_tensor, and data out of '_index' to zero
// - Step 2.1 Set slice tensor with value
// - Step 2.1 Set the data of slice tensor to 0
slice_e.device(eigen_place) = slice_e.constant(T(0));
// - Step 2.2 Set slice tensor with value
if (value_tensor != nullptr) { if (value_tensor != nullptr) {
// ElementwiseComputeEx can do broadcasting // ElementwiseComputeEx can do broadcasting
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>( ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, &slice_t, value_tensor, -1, SubFunctor<T>(), &slice_t); ctx, &slice_t, value_tensor, -1, SubFunctor<T>(), &slice_t);
} else { } else {
Tensor value_t(dtype); Tensor value_t(dtype);
auto value_dims = framework::make_ddim(shape);
value_t.mutable_data<T>(value_dims, place); value_t.mutable_data<T>(value_dims, place);
auto value_name = GetValueName(dtype); auto value_name = GetValueName(dtype);
CopyVecotorToTensor<T>(value_name.c_str(), &value_t, ctx); CopyVecotorToTensor<T>(value_name.c_str(), &value_t, ctx);
...@@ -208,8 +261,10 @@ class SetValueKernel : public framework::OpKernel<T> { ...@@ -208,8 +261,10 @@ class SetValueKernel : public framework::OpKernel<T> {
ctx, &slice_t, &value_t, -1, SubFunctor<T>(), &slice_t); ctx, &slice_t, &value_t, -1, SubFunctor<T>(), &slice_t);
} }
// - Step 2.3 Pad slice tensor with 0 // - Step 2.2 Pad slice tensor with 0
pad_e.device(eigen_place) = slice_e.pad(paddings, T(0)); pad_e.device(eigen_place) = pad_e.constant(T(0));
pad_e.stridedSlice(starts_indices, ends_indices, strides_indices)
.device(eigen_place) = slice_e;
// Step 3: Set out tensor with value_tensor // Step 3: Set out tensor with value_tensor
out_e.device(eigen_place) = out_e - pad_e; out_e.device(eigen_place) = out_e - pad_e;
......
...@@ -587,8 +587,16 @@ void BindImperative(py::module *m_ptr) { ...@@ -587,8 +587,16 @@ void BindImperative(py::module *m_ptr) {
? PyTuple_Pack(1, _index.ptr()) ? PyTuple_Pack(1, _index.ptr())
: _index.ptr(); : _index.ptr();
// 1. Check argumnets // 1. Check argumnets
// 1.1 Check whether _index can be parsed. // 1.1 Check whether value obj is a tensor.
bool value_is_tensor = true;
bool parse_index = true; bool parse_index = true;
if (py::isinstance<py::array>(value_obj) ||
py::isinstance<py::int_>(value_obj) ||
py::isinstance<py::float_>(value_obj)) {
value_is_tensor = false;
}
// 1.2 Check whether _index can be parsed.
const int size = PyTuple_GET_SIZE(index_ptr); const int size = PyTuple_GET_SIZE(index_ptr);
for (int dim = 0; dim < size; ++dim) { for (int dim = 0; dim < size; ++dim) {
PyObject *slice_item = PyTuple_GetItem(index_ptr, dim); PyObject *slice_item = PyTuple_GetItem(index_ptr, dim);
...@@ -598,34 +606,20 @@ void BindImperative(py::module *m_ptr) { ...@@ -598,34 +606,20 @@ void BindImperative(py::module *m_ptr) {
} }
} }
// 1.2 Check whether stride is 1.
std::vector<int> axes, starts, ends, strides, decrease_axis,
infer_flags;
bool stride_is_1 = true;
if (parse_index) {
ParseIndexingSlice(self_tensor, index_ptr, &axes, &starts, &ends,
&strides, &decrease_axis, &infer_flags);
stride_is_1 =
std::all_of(strides.cbegin(), strides.cend(),
[](int64_t stride) { return stride == 1; });
}
// 1.3 Check whether value obj is a tensor.
bool value_is_tensor = true;
if (py::isinstance<py::array>(value_obj) ||
py::isinstance<py::int_>(value_obj) ||
py::isinstance<py::float_>(value_obj)) {
value_is_tensor = false;
}
// 2. Call op set_value to speed up if the condition is met, // 2. Call op set_value to speed up if the condition is met,
// otherwise call TensorToPyArray. // otherwise call TensorToPyArray.
// TODO(liym27): Try not to call TensorToPyArray because it always // TODO(liym27): Try not to call TensorToPyArray because it always
// copys data to cpu place, which reduces performance. // copys data to cpu place, which reduces performance.
if (parse_index && stride_is_1 && value_is_tensor) { if (parse_index && value_is_tensor) {
framework::AttributeMap attrs = { std::vector<int> axes, starts, ends, steps, decrease_axis,
{"axes", axes}, {"starts", starts}, {"ends", ends}}; infer_flags;
ParseIndexingSlice(self_tensor, index_ptr, &axes, &starts, &ends,
&steps, &decrease_axis, &infer_flags);
framework::AttributeMap attrs = {{"axes", axes},
{"starts", starts},
{"ends", ends},
{"steps", steps}};
imperative::NameVarBaseMap ins = {{"Input", {self}}}; imperative::NameVarBaseMap ins = {{"Input", {self}}};
imperative::NameVarBaseMap outs = {{"Out", {self}}}; imperative::NameVarBaseMap outs = {{"Out", {self}}};
......
...@@ -1866,6 +1866,8 @@ class Variable(object): ...@@ -1866,6 +1866,8 @@ class Variable(object):
axes = [] axes = []
starts = [] starts = []
ends = [] ends = []
steps = []
max_integer = sys.maxsize max_integer = sys.maxsize
def replace_ellipsis(item): def replace_ellipsis(item):
...@@ -1877,7 +1879,12 @@ class Variable(object): ...@@ -1877,7 +1879,12 @@ class Variable(object):
# var[0, ..., 1:2] -> var[0, :, :, 1:2] # var[0, ..., 1:2] -> var[0, :, :, 1:2]
item = list(item) item = list(item)
ell_count = item.count(Ellipsis)
# Remove Variable to skip bug when counting Ellipsis
item_remove_var = [
ele for ele in item if not isinstance(ele, Variable)
]
ell_count = item_remove_var.count(Ellipsis)
if ell_count == 0: if ell_count == 0:
return item return item
elif ell_count > 1: elif ell_count > 1:
...@@ -1905,23 +1912,47 @@ class Variable(object): ...@@ -1905,23 +1912,47 @@ class Variable(object):
if start is None and end is None and step is None: if start is None and end is None and step is None:
continue continue
start = 0 if start is None else start
step = 1 if step is None else step step = 1 if step is None else step
# TODO: support cases when step != 1 # TODO: support cases when step < 1
if step != 1: if not isinstance(step, Variable) and step == 0:
raise ValueError( raise ValueError(
"When assign a value to a paddle.Tensor, only support step is 1, " "When assign a value to a paddle.Tensor, step can not be 0, "
"but received step is {}.".format(step)) "but received step is {}.".format(step))
end = max_integer if end is None else end
if isinstance(step, Variable) and (start is None or
end is None):
raise ValueError(
"When assign a value to a paddle.Tensor, it's not supported that "
"the start or end is None when the type of step is paddle.Tensor."
)
if start is None:
start = 0 if step > 0 else max_integer
if end is None:
end = max_integer if step > 0 else (0 - max_integer)
else: else:
start = slice_item start = slice_item
end = slice_item + 1 if slice_item != -1 else max_integer end = slice_item + 1 if slice_item != -1 else max_integer
step = 1
axes.append(dim) axes.append(dim)
starts.append(start) starts.append(start)
ends.append(end) ends.append(end)
steps.append(step)
attrs = {'axes': axes, 'starts': starts, 'ends': ends, 'steps': steps}
attrs = {'axes': axes, 'starts': starts, 'ends': ends} from .layers import utils
if utils._contain_var(starts):
inputs['StartsTensorList'] = utils._convert_to_tensor_list(starts)
del attrs['starts']
if utils._contain_var(ends):
inputs['EndsTensorList'] = utils._convert_to_tensor_list(ends)
del attrs['ends']
if utils._contain_var(steps):
inputs['StepsTensorList'] = utils._convert_to_tensor_list(steps)
del attrs['steps']
# 2. Parse value # 2. Parse value
dtype = self.dtype dtype = self.dtype
...@@ -1968,6 +1999,7 @@ class Variable(object): ...@@ -1968,6 +1999,7 @@ class Variable(object):
self.block.append_op( self.block.append_op(
type="set_value", inputs=inputs, outputs={'Out': self}, attrs=attrs) type="set_value", inputs=inputs, outputs={'Out': self}, attrs=attrs)
return self return self
......
...@@ -27,10 +27,13 @@ class TestSetValueBase(unittest.TestCase): ...@@ -27,10 +27,13 @@ class TestSetValueBase(unittest.TestCase):
paddle.enable_static() paddle.enable_static()
self.set_dtype() self.set_dtype()
self.set_value() self.set_value()
self.shape = [2, 3, 4] self.set_shape()
self.data = np.ones(self.shape).astype(self.dtype) self.data = np.ones(self.shape).astype(self.dtype)
self.program = paddle.static.Program() self.program = paddle.static.Program()
def set_shape(self):
self.shape = [2, 3, 4]
def set_value(self): def set_value(self):
self.value = 6 self.value = 6
...@@ -59,7 +62,8 @@ class TestSetValueApi(TestSetValueBase): ...@@ -59,7 +62,8 @@ class TestSetValueApi(TestSetValueBase):
self.data, out)) self.data, out))
# 1. Test different type of item: int, python slice, Ellipsis # 1. Test different type of item: int, Python slice, Paddle Tensor
# 1.1 item is int
class TestSetValueItemInt(TestSetValueApi): class TestSetValueItemInt(TestSetValueApi):
def _call_setitem(self, x): def _call_setitem(self, x):
x[0] = self.value x[0] = self.value
...@@ -68,6 +72,8 @@ class TestSetValueItemInt(TestSetValueApi): ...@@ -68,6 +72,8 @@ class TestSetValueItemInt(TestSetValueApi):
self.data[0] = self.value self.data[0] = self.value
# 1.2 item is slice
# 1.2.1 step is 1
class TestSetValueItemSlice(TestSetValueApi): class TestSetValueItemSlice(TestSetValueApi):
def _call_setitem(self, x): def _call_setitem(self, x):
x[0:2] = self.value x[0:2] = self.value
...@@ -100,6 +106,102 @@ class TestSetValueItemSlice4(TestSetValueApi): ...@@ -100,6 +106,102 @@ class TestSetValueItemSlice4(TestSetValueApi):
self.data[0:, 1:2, :] = self.value self.data[0:, 1:2, :] = self.value
# 1.2.2 step > 1
class TestSetValueItemSliceStep(TestSetValueApi):
def set_shape(self):
self.shape = [5, 5, 5]
def _call_setitem(self, x):
x[0:2:2] = self.value
def _get_answer(self):
self.data[0:2:2] = self.value
class TestSetValueItemSliceStep2(TestSetValueApi):
def set_shape(self):
self.shape = [7, 5, 5]
def _call_setitem(self, x):
x[0:-1:3] = self.value
def _get_answer(self):
self.data[0:-1:3] = self.value
class TestSetValueItemSliceStep3(TestSetValueApi):
def _call_setitem(self, x):
x[0:-1, 0:2, ::2] = self.value
def _get_answer(self):
self.data[0:-1, 0:2, ::2] = self.value
class TestSetValueItemSliceStep4(TestSetValueApi):
def _call_setitem(self, x):
x[0:, 1:2:2, :] = self.value
def _get_answer(self):
self.data[0:, 1:2:2, :] = self.value
# 1.2.3 step < 0
class TestSetValueItemSliceNegetiveStep(TestSetValueApi):
def set_shape(self):
self.shape = [5, 2]
def set_value(self):
self.value = np.array([3, 4])
def _call_setitem(self, x):
x[5:2:-1] = self.value
def _get_answer(self):
self.data[5:2:-1] = self.value
class TestSetValueItemSliceNegetiveStep2(TestSetValueApi):
def set_shape(self):
self.shape = [5]
def set_value(self):
self.value = np.array([3, 4])
def _call_setitem(self, x):
x[1::-1] = self.value
def _get_answer(self):
self.data[1::-1] = self.value
class TestSetValueItemSliceNegetiveStep3(TestSetValueApi):
def set_shape(self):
self.shape = [3]
def set_value(self):
self.value = np.array([3, 4, 5])
def _call_setitem(self, x):
x[::-1] = self.value
def _get_answer(self):
self.data[::-1] = self.value
class TestSetValueItemSliceNegetiveStep4(TestSetValueApi):
def set_shape(self):
self.shape = [3, 4, 5]
def _call_setitem(self, x):
x[2:0:-1, 0:2, ::-1] = self.value
def _get_answer(self):
self.data[2:0:-1, 0:2, ::-1] = self.value
# 1.3 item is Ellipsis
class TestSetValueItemEllipsis1(TestSetValueApi): class TestSetValueItemEllipsis1(TestSetValueApi):
def _call_setitem(self, x): def _call_setitem(self, x):
x[0:, ..., 1:] = self.value x[0:, ..., 1:] = self.value
...@@ -132,6 +234,69 @@ class TestSetValueItemEllipsis4(TestSetValueApi): ...@@ -132,6 +234,69 @@ class TestSetValueItemEllipsis4(TestSetValueApi):
self.data[...] = self.value self.data[...] = self.value
# 1.4 item is Paddle Tensor
class TestSetValueItemTensor(TestSetValueApi):
def _call_setitem(self, x):
zero = paddle.full([1], 0, dtype="int32")
x[zero] = self.value
def _get_answer(self):
self.data[0] = self.value
class TestSetValueItemTensor2(TestSetValueApi):
def _call_setitem(self, x):
zero = paddle.full([1], 0, dtype="int32")
two = paddle.full([1], 2, dtype="int64")
x[zero:two] = self.value
def _get_answer(self):
self.data[0:2] = self.value
class TestSetValueItemTensor3(TestSetValueApi):
def _call_setitem(self, x):
zero = paddle.full([1], 0, dtype="int32")
two = paddle.full([1], 2, dtype="int64")
x[zero:-1, 0:two] = self.value
def _get_answer(self):
self.data[0:-1, 0:2] = self.value
class TestSetValueItemTensor4(TestSetValueApi):
def _call_setitem(self, x):
zero = paddle.full([1], 0, dtype="int32")
two = paddle.full([1], 2, dtype="int64")
x[0:-1, zero:2, 0:6:two] = self.value
def _get_answer(self):
self.data[0:-1, 0:2, ::2] = self.value
class TestSetValueItemTensor5(TestSetValueApi):
def _call_setitem(self, x):
zero = paddle.full([1], 0, dtype="int32")
two = paddle.full([1], 2, dtype="int64")
x[zero:, 1:2:two, :] = self.value
def _get_answer(self):
self.data[0:, 1:2:2, :] = self.value
class TestSetValueItemTensor6(TestSetValueApi):
def set_shape(self):
self.shape = [3, 4, 5]
def _call_setitem(self, x):
minus1 = paddle.full([1], -1, dtype="int32")
zero = paddle.full([1], 0, dtype="int32")
x[2:zero:minus1, 0:2, 10:-6:minus1] = self.value
def _get_answer(self):
self.data[2:0:-1, 0:2, ::-1] = self.value
# 2. Test different type of value: int, float, numpy.ndarray, Tensor # 2. Test different type of value: int, float, numpy.ndarray, Tensor
# 2.1 value is int32, int64, float32, float64, bool # 2.1 value is int32, int64, float32, float64, bool
...@@ -526,15 +691,19 @@ class TestError(TestSetValueBase): ...@@ -526,15 +691,19 @@ class TestError(TestSetValueBase):
y[0] = 1 y[0] = 1
def _step_error(self): def _step_error(self):
with self.assertRaisesRegexp(ValueError, "only support step is 1"): with self.assertRaisesRegexp(ValueError, "step can not be 0"):
x = paddle.ones(shape=self.shape, dtype=self.dtype) x = paddle.ones(shape=self.shape, dtype=self.dtype)
x[0:1:2] = self.value x[0:1:0] = self.value
def _ellipsis_error(self): def _ellipsis_error(self):
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
IndexError, "An index can only have a single ellipsis"): IndexError, "An index can only have a single ellipsis"):
x = paddle.ones(shape=self.shape, dtype=self.dtype) x = paddle.ones(shape=self.shape, dtype=self.dtype)
x[..., ...] = self.value x[..., ...] = self.value
with self.assertRaisesRegexp(ValueError, "the start or end is None"):
x = paddle.ones(shape=self.shape, dtype=self.dtype)
one = paddle.ones([1])
x[::one] = self.value
def _broadcast_mismatch(self): def _broadcast_mismatch(self):
program = paddle.static.Program() program = paddle.static.Program()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册