未验证 提交 f775bfc1 编写于 作者: Z zyfncg 提交者: GitHub

Support setitem by None index (#34442)

* Support setitem by None index

* remove unreachable code

* Add Checkpoint for set_value_op because add a new attribute
上级 30416052
......@@ -127,6 +127,8 @@ class SetValueMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::vector<int64_t>>("decrease_axes",
"(list<int>) The axes to decrease.")
.SetDefault({});
AddAttr<std::vector<int64_t>>("none_axes", "(list<int>) The axes to none.")
.SetDefault({});
AddAttr<std::vector<int>>("bool_values", "Store the bool values.")
.SetDefault({});
......@@ -247,4 +249,10 @@ Upgrade set_value, add 3 inputs [StartsTensorList, EndsTensorList, StepsTensorLi
Upgrade set_value, add 1 attribute [decrease_axes].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"decrease_axes", "The axes to decrease.", std::vector<int64_t>{}));
"decrease_axes", "The axes to decrease.", std::vector<int64_t>{}))
.AddCheckpoint(
R"ROC(
Upgrade set_value, add 1 attribute [none_axes].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"none_axes", "The axes with none index.", std::vector<int64_t>{}));
......@@ -60,6 +60,47 @@ inline std::string GetValueName(framework::proto::VarType::Type data_type) {
return value_name;
}
// check whether the tensor with dimension of second can assign to the
// tensor with dimension of first
inline void CheckIsDimsMatch(const framework::DDim first,
const framework::DDim second) {
int ignore_axis1 = 0, ignore_axis2 = 0;
for (; ignore_axis1 < first.size(); ++ignore_axis1) {
if (first[ignore_axis1] != 1) {
break;
}
}
for (; ignore_axis2 < second.size(); ++ignore_axis2) {
if (second[ignore_axis2] != 1) {
break;
}
}
if (second.size() == ignore_axis2) {
// second tensor has only one value
return;
}
if (first.size() - ignore_axis1 >= second.size() - ignore_axis2) {
auto idx1 = first.size() - 1;
auto idx2 = second.size() - 1;
bool is_match = true;
for (; idx2 >= ignore_axis2; idx2--) {
if (first[idx1--] != second[idx2] && second[idx2] != 1) {
is_match = false;
break;
}
}
if (is_match) {
return;
}
}
PADDLE_THROW(platform::errors::InvalidArgument(
"The shape of tensor assigned value must match the shape "
"of target shape: %d, but now shape is %d.",
second.to_str(), first.to_str()));
}
template <typename DeviceContext, typename T>
class SetValueKernel : public framework::OpKernel<T> {
public:
......@@ -113,6 +154,7 @@ class SetValueKernel : public framework::OpKernel<T> {
auto steps = ctx.Attr<std::vector<int64_t>>("steps");
auto shape = ctx.Attr<std::vector<int64_t>>("shape");
auto decrease_axes = ctx.Attr<std::vector<int64_t>>("decrease_axes");
auto none_axes = ctx.Attr<std::vector<int64_t>>("none_axes");
auto dtype = in->type();
if (!starts_tensor_list.empty()) {
......@@ -130,6 +172,32 @@ class SetValueKernel : public framework::OpKernel<T> {
auto slice_dims = GetSliceDims(in_dims, axes, starts, ends, &steps);
auto decrease_slice_dims = GetDecreasedDims(slice_dims, decrease_axes);
auto slice_dims_for_assign = decrease_slice_dims;
if (!none_axes.empty()) {
std::vector<int64_t> slice_dims_with_none;
size_t none_axes_cur = 0, decrease_axes_cur = 0;
for (int i = 0; i < slice_dims.size(); ++i) {
while (none_axes_cur < none_axes.size() &&
none_axes[none_axes_cur] <= i) {
slice_dims_with_none.push_back(1);
none_axes_cur++;
}
if (decrease_axes_cur < decrease_axes.size() &&
decrease_axes[decrease_axes_cur] == i) {
decrease_axes_cur++;
} else {
slice_dims_with_none.push_back(slice_dims[i]);
}
}
while (none_axes_cur < none_axes.size()) {
slice_dims_with_none.push_back(1);
none_axes_cur++;
}
slice_dims_for_assign = framework::make_ddim(slice_dims_with_none);
}
auto place = ctx.GetPlace();
auto& eigen_place =
*ctx.template device_context<DeviceContext>().eigen_device();
......@@ -194,14 +262,17 @@ class SetValueKernel : public framework::OpKernel<T> {
// If do broadcasting on Tensor with shape [3] and [3], the result's shape
// is [3], which is right.
slice_tensor.Resize(decrease_slice_dims);
slice_tensor.Resize(slice_dims_for_assign);
if (value_tensor != nullptr) {
CheckIsDimsMatch(slice_dims_for_assign, value_tensor->dims());
// ElementwiseComputeEx can do broadcasting
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, &slice_tensor, value_tensor, -1, SubFunctor<T>(), &slice_tensor);
} else {
Tensor value_t(dtype);
auto value_dims = framework::make_ddim(shape);
CheckIsDimsMatch(slice_dims_for_assign, value_dims);
value_t.mutable_data<T>(value_dims, place);
auto value_name = GetValueName(dtype);
CopyVecotorToTensor<T>(value_name.c_str(), &value_t, ctx);
......
......@@ -333,6 +333,79 @@ class TestSetValueItemTensor6(TestSetValueApi):
self.data[2:0:-1, 0:2, ::-1] = self.value
# 1.5 item is None
class TestSetValueItemNone1(TestSetValueApi):
def _call_setitem(self, x):
x[None] = self.value
def _get_answer(self):
self.data[None] = self.value
class TestSetValueItemNone2(TestSetValueApi):
def _call_setitem(self, x):
x[0, None, 1] = self.value
def _get_answer(self):
self.data[0, None, 1] = self.value
class TestSetValueItemNone3(TestSetValueApi):
def _call_setitem(self, x):
x[:, None, None, 1] = self.value
def _get_answer(self):
self.data[:, None, None, 1] = self.value
class TestSetValueItemNone4(TestSetValueApi):
def _call_setitem(self, x):
x[0, 0, None, 1] = self.value
def _get_answer(self):
self.data[0, 0, None, 1] = self.value
class TestSetValueItemNone5(TestSetValueApi):
def _call_setitem(self, x):
x[0, None, 0, None, 1] = self.value
def _get_answer(self):
self.data[0, None, 0, None, 1] = self.value
class TestSetValueItemNone6(TestSetValueApi):
def _call_setitem(self, x):
x[None, 0, 0, None, 0] = self.value
def _get_answer(self):
self.data[None, 0, 0, None, 0] = self.value
class TestSetValueItemNone7(TestSetValueApi):
def _call_setitem(self, x):
x[:, None, 1] = np.zeros(self.shape)[:, None, 0]
def _get_answer(self):
self.data[:, None, 1] = np.zeros(self.shape)[:, None, 0]
class TestSetValueItemNone8(TestSetValueApi):
def _call_setitem(self, x):
x[:, 1, None] = np.zeros(self.shape)[:, 0, None]
def _get_answer(self):
self.data[:, 1, None] = np.zeros(self.shape)[:, 0, None]
class TestSetValueItemNone9(TestSetValueApi):
def _call_setitem(self, x):
x[None, :, 1, ..., None] = np.zeros(self.shape)[0, 0, :, None]
def _get_answer(self):
self.data[None, :, 1, ..., None] = np.zeros(self.shape)[0, 0, :, None]
# 2. Test different type of value: int, float, numpy.ndarray, Tensor
# 2.1 value is int32, int64, float32, float64, bool
......@@ -762,8 +835,7 @@ class TestError(TestSetValueBase):
value = np.array([3, 4, 5, 6, 7])
x[0] = value
exe = paddle.static.Executor(paddle.CPUPlace())
with self.assertRaisesRegexp(ValueError,
"Broadcast dimension mismatch."):
with self.assertRaises(ValueError):
exe.run(program)
def test_error(self):
......
......@@ -289,9 +289,11 @@ def _setitem_impl_(var, item, value):
ends = []
steps = []
item, none_axes = replace_none(item)
item = replace_ellipsis(var, item)
for dim, slice_item in enumerate(item):
dim = 0
for _, slice_item in enumerate(item):
if is_integer_or_scalar_tensor(slice_item):
decrease_axes.append(dim)
start = slice_item
......@@ -304,6 +306,7 @@ def _setitem_impl_(var, item, value):
step = slice_item.step
if start is None and end is None and step is None:
dim += 1
continue
step = 1 if step is None else step
......@@ -326,7 +329,7 @@ def _setitem_impl_(var, item, value):
end = MAX_INTEGER if step > 0 else (0 - MAX_INTEGER)
else:
raise IndexError(
"Valid index accept int or slice or ellipsis, but received {}.".
"Valid index accept int, slice, ellipsis or None, but received {}.".
format(slice_item))
axes.append(dim)
......@@ -334,12 +337,15 @@ def _setitem_impl_(var, item, value):
ends.append(end)
steps.append(step)
dim += 1
attrs = {
'axes': axes,
'starts': starts,
'ends': ends,
'steps': steps,
'decrease_axes': decrease_axes
'decrease_axes': decrease_axes,
'none_axes': none_axes
}
from .layers import utils
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册