未验证 提交 f775bfc1 编写于 作者: Z zyfncg 提交者: GitHub

Support setitem by None index (#34442)

* Support setitem by None index

* remove unreachable code

* Add checkpoint for set_value_op because a new attribute was added
上级 30416052
...@@ -127,6 +127,8 @@ class SetValueMaker : public framework::OpProtoAndCheckerMaker { ...@@ -127,6 +127,8 @@ class SetValueMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::vector<int64_t>>("decrease_axes", AddAttr<std::vector<int64_t>>("decrease_axes",
"(list<int>) The axes to decrease.") "(list<int>) The axes to decrease.")
.SetDefault({}); .SetDefault({});
AddAttr<std::vector<int64_t>>("none_axes", "(list<int>) The axes to none.")
.SetDefault({});
AddAttr<std::vector<int>>("bool_values", "Store the bool values.") AddAttr<std::vector<int>>("bool_values", "Store the bool values.")
.SetDefault({}); .SetDefault({});
...@@ -247,4 +249,10 @@ Upgrade set_value, add 3 inputs [StartsTensorList, EndsTensorList, StepsTensorLi ...@@ -247,4 +249,10 @@ Upgrade set_value, add 3 inputs [StartsTensorList, EndsTensorList, StepsTensorLi
Upgrade set_value, add 1 attribute [decrease_axes]. Upgrade set_value, add 1 attribute [decrease_axes].
)ROC", )ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr( paddle::framework::compatible::OpVersionDesc().NewAttr(
"decrease_axes", "The axes to decrease.", std::vector<int64_t>{})); "decrease_axes", "The axes to decrease.", std::vector<int64_t>{}))
.AddCheckpoint(
R"ROC(
Upgrade set_value, add 1 attribute [none_axes].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"none_axes", "The axes with none index.", std::vector<int64_t>{}));
...@@ -60,6 +60,47 @@ inline std::string GetValueName(framework::proto::VarType::Type data_type) { ...@@ -60,6 +60,47 @@ inline std::string GetValueName(framework::proto::VarType::Type data_type) {
return value_name; return value_name;
} }
// Check whether a tensor with dims `second` (the assigned value) can be
// assigned to a tensor with dims `first` (the target slice), following
// broadcasting rules: leading dims of size 1 are ignored on both sides,
// and each remaining trailing dim of `second` must equal the corresponding
// dim of `first` or be 1. Throws InvalidArgument otherwise.
inline void CheckIsDimsMatch(const framework::DDim first,
                             const framework::DDim second) {
  // Skip leading dimensions of size 1 in both shapes.
  int ignore_axis1 = 0, ignore_axis2 = 0;
  for (; ignore_axis1 < first.size(); ++ignore_axis1) {
    if (first[ignore_axis1] != 1) {
      break;
    }
  }
  for (; ignore_axis2 < second.size(); ++ignore_axis2) {
    if (second[ignore_axis2] != 1) {
      break;
    }
  }

  if (second.size() == ignore_axis2) {
    // The value tensor holds a single element; it broadcasts to any shape.
    return;
  }

  if (first.size() - ignore_axis1 >= second.size() - ignore_axis2) {
    // Compare dimensions right-to-left; a dim of 1 in `second` broadcasts.
    auto idx1 = first.size() - 1;
    auto idx2 = second.size() - 1;
    bool is_match = true;
    for (; idx2 >= ignore_axis2; idx2--) {
      if (first[idx1--] != second[idx2] && second[idx2] != 1) {
        is_match = false;
        break;
      }
    }
    if (is_match) {
      return;
    }
  }
  // BUGFIX: DDim::to_str() returns std::string, so the format specifiers
  // must be %s, not %d.
  PADDLE_THROW(platform::errors::InvalidArgument(
      "The shape of tensor assigned value must match the shape "
      "of target shape: %s, but now shape is %s.",
      second.to_str(), first.to_str()));
}
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class SetValueKernel : public framework::OpKernel<T> { class SetValueKernel : public framework::OpKernel<T> {
public: public:
...@@ -113,6 +154,7 @@ class SetValueKernel : public framework::OpKernel<T> { ...@@ -113,6 +154,7 @@ class SetValueKernel : public framework::OpKernel<T> {
auto steps = ctx.Attr<std::vector<int64_t>>("steps"); auto steps = ctx.Attr<std::vector<int64_t>>("steps");
auto shape = ctx.Attr<std::vector<int64_t>>("shape"); auto shape = ctx.Attr<std::vector<int64_t>>("shape");
auto decrease_axes = ctx.Attr<std::vector<int64_t>>("decrease_axes"); auto decrease_axes = ctx.Attr<std::vector<int64_t>>("decrease_axes");
auto none_axes = ctx.Attr<std::vector<int64_t>>("none_axes");
auto dtype = in->type(); auto dtype = in->type();
if (!starts_tensor_list.empty()) { if (!starts_tensor_list.empty()) {
...@@ -130,6 +172,32 @@ class SetValueKernel : public framework::OpKernel<T> { ...@@ -130,6 +172,32 @@ class SetValueKernel : public framework::OpKernel<T> {
auto slice_dims = GetSliceDims(in_dims, axes, starts, ends, &steps); auto slice_dims = GetSliceDims(in_dims, axes, starts, ends, &steps);
auto decrease_slice_dims = GetDecreasedDims(slice_dims, decrease_axes); auto decrease_slice_dims = GetDecreasedDims(slice_dims, decrease_axes);
auto slice_dims_for_assign = decrease_slice_dims;
if (!none_axes.empty()) {
std::vector<int64_t> slice_dims_with_none;
size_t none_axes_cur = 0, decrease_axes_cur = 0;
for (int i = 0; i < slice_dims.size(); ++i) {
while (none_axes_cur < none_axes.size() &&
none_axes[none_axes_cur] <= i) {
slice_dims_with_none.push_back(1);
none_axes_cur++;
}
if (decrease_axes_cur < decrease_axes.size() &&
decrease_axes[decrease_axes_cur] == i) {
decrease_axes_cur++;
} else {
slice_dims_with_none.push_back(slice_dims[i]);
}
}
while (none_axes_cur < none_axes.size()) {
slice_dims_with_none.push_back(1);
none_axes_cur++;
}
slice_dims_for_assign = framework::make_ddim(slice_dims_with_none);
}
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
auto& eigen_place = auto& eigen_place =
*ctx.template device_context<DeviceContext>().eigen_device(); *ctx.template device_context<DeviceContext>().eigen_device();
...@@ -194,14 +262,17 @@ class SetValueKernel : public framework::OpKernel<T> { ...@@ -194,14 +262,17 @@ class SetValueKernel : public framework::OpKernel<T> {
// If do broadcasting on Tensor with shape [3] and [3], the result's shape // If do broadcasting on Tensor with shape [3] and [3], the result's shape
// is [3], which is right. // is [3], which is right.
slice_tensor.Resize(decrease_slice_dims); slice_tensor.Resize(slice_dims_for_assign);
if (value_tensor != nullptr) { if (value_tensor != nullptr) {
CheckIsDimsMatch(slice_dims_for_assign, value_tensor->dims());
// ElementwiseComputeEx can do broadcasting // ElementwiseComputeEx can do broadcasting
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>( ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, &slice_tensor, value_tensor, -1, SubFunctor<T>(), &slice_tensor); ctx, &slice_tensor, value_tensor, -1, SubFunctor<T>(), &slice_tensor);
} else { } else {
Tensor value_t(dtype); Tensor value_t(dtype);
auto value_dims = framework::make_ddim(shape); auto value_dims = framework::make_ddim(shape);
CheckIsDimsMatch(slice_dims_for_assign, value_dims);
value_t.mutable_data<T>(value_dims, place); value_t.mutable_data<T>(value_dims, place);
auto value_name = GetValueName(dtype); auto value_name = GetValueName(dtype);
CopyVecotorToTensor<T>(value_name.c_str(), &value_t, ctx); CopyVecotorToTensor<T>(value_name.c_str(), &value_t, ctx);
......
...@@ -333,6 +333,79 @@ class TestSetValueItemTensor6(TestSetValueApi): ...@@ -333,6 +333,79 @@ class TestSetValueItemTensor6(TestSetValueApi):
self.data[2:0:-1, 0:2, ::-1] = self.value self.data[2:0:-1, 0:2, ::-1] = self.value
# 1.5 item is None
class TestSetValueItemNone1(TestSetValueApi):
    """setitem with a single None index: x[None] = value."""

    def _call_setitem(self, x):
        x[None] = self.value

    def _get_answer(self):
        self.data[None] = self.value
class TestSetValueItemNone2(TestSetValueApi):
    """setitem mixing int indices with None: x[0, None, 1] = value."""

    def _call_setitem(self, x):
        x[0, None, 1] = self.value

    def _get_answer(self):
        self.data[0, None, 1] = self.value
class TestSetValueItemNone3(TestSetValueApi):
    """setitem with slice plus consecutive None indices."""

    def _call_setitem(self, x):
        x[:, None, None, 1] = self.value

    def _get_answer(self):
        self.data[:, None, None, 1] = self.value
class TestSetValueItemNone4(TestSetValueApi):
    """setitem with two int indices followed by None: x[0, 0, None, 1]."""

    def _call_setitem(self, x):
        x[0, 0, None, 1] = self.value

    def _get_answer(self):
        self.data[0, 0, None, 1] = self.value
class TestSetValueItemNone5(TestSetValueApi):
    """setitem interleaving int indices and None: x[0, None, 0, None, 1]."""

    def _call_setitem(self, x):
        x[0, None, 0, None, 1] = self.value

    def _get_answer(self):
        self.data[0, None, 0, None, 1] = self.value
class TestSetValueItemNone6(TestSetValueApi):
    """setitem with a leading None index: x[None, 0, 0, None, 0]."""

    def _call_setitem(self, x):
        x[None, 0, 0, None, 0] = self.value

    def _get_answer(self):
        self.data[None, 0, 0, None, 0] = self.value
class TestSetValueItemNone7(TestSetValueApi):
    """setitem where the assigned value itself carries a None axis."""

    def _call_setitem(self, x):
        x[:, None, 1] = np.zeros(self.shape)[:, None, 0]

    def _get_answer(self):
        self.data[:, None, 1] = np.zeros(self.shape)[:, None, 0]
class TestSetValueItemNone8(TestSetValueApi):
    """setitem with a trailing None and a None-shaped value tensor."""

    def _call_setitem(self, x):
        x[:, 1, None] = np.zeros(self.shape)[:, 0, None]

    def _get_answer(self):
        self.data[:, 1, None] = np.zeros(self.shape)[:, 0, None]
class TestSetValueItemNone9(TestSetValueApi):
    """setitem combining None, slice, int, and ellipsis indices."""

    def _call_setitem(self, x):
        x[None, :, 1, ..., None] = np.zeros(self.shape)[0, 0, :, None]

    def _get_answer(self):
        self.data[None, :, 1, ..., None] = np.zeros(self.shape)[0, 0, :, None]
# 2. Test different type of value: int, float, numpy.ndarray, Tensor # 2. Test different type of value: int, float, numpy.ndarray, Tensor
# 2.1 value is int32, int64, float32, float64, bool # 2.1 value is int32, int64, float32, float64, bool
...@@ -762,8 +835,7 @@ class TestError(TestSetValueBase): ...@@ -762,8 +835,7 @@ class TestError(TestSetValueBase):
value = np.array([3, 4, 5, 6, 7]) value = np.array([3, 4, 5, 6, 7])
x[0] = value x[0] = value
exe = paddle.static.Executor(paddle.CPUPlace()) exe = paddle.static.Executor(paddle.CPUPlace())
with self.assertRaisesRegexp(ValueError, with self.assertRaises(ValueError):
"Broadcast dimension mismatch."):
exe.run(program) exe.run(program)
def test_error(self): def test_error(self):
......
...@@ -289,9 +289,11 @@ def _setitem_impl_(var, item, value): ...@@ -289,9 +289,11 @@ def _setitem_impl_(var, item, value):
ends = [] ends = []
steps = [] steps = []
item, none_axes = replace_none(item)
item = replace_ellipsis(var, item) item = replace_ellipsis(var, item)
for dim, slice_item in enumerate(item): dim = 0
for _, slice_item in enumerate(item):
if is_integer_or_scalar_tensor(slice_item): if is_integer_or_scalar_tensor(slice_item):
decrease_axes.append(dim) decrease_axes.append(dim)
start = slice_item start = slice_item
...@@ -304,6 +306,7 @@ def _setitem_impl_(var, item, value): ...@@ -304,6 +306,7 @@ def _setitem_impl_(var, item, value):
step = slice_item.step step = slice_item.step
if start is None and end is None and step is None: if start is None and end is None and step is None:
dim += 1
continue continue
step = 1 if step is None else step step = 1 if step is None else step
...@@ -326,7 +329,7 @@ def _setitem_impl_(var, item, value): ...@@ -326,7 +329,7 @@ def _setitem_impl_(var, item, value):
end = MAX_INTEGER if step > 0 else (0 - MAX_INTEGER) end = MAX_INTEGER if step > 0 else (0 - MAX_INTEGER)
else: else:
raise IndexError( raise IndexError(
"Valid index accept int or slice or ellipsis, but received {}.". "Valid index accept int, slice, ellipsis or None, but received {}.".
format(slice_item)) format(slice_item))
axes.append(dim) axes.append(dim)
...@@ -334,12 +337,15 @@ def _setitem_impl_(var, item, value): ...@@ -334,12 +337,15 @@ def _setitem_impl_(var, item, value):
ends.append(end) ends.append(end)
steps.append(step) steps.append(step)
dim += 1
attrs = { attrs = {
'axes': axes, 'axes': axes,
'starts': starts, 'starts': starts,
'ends': ends, 'ends': ends,
'steps': steps, 'steps': steps,
'decrease_axes': decrease_axes 'decrease_axes': decrease_axes,
'none_axes': none_axes
} }
from .layers import utils from .layers import utils
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册