未验证 提交 525c32e3 编写于 作者: L liym27 提交者: GitHub

Fix bug of set_value op: Decrease axes to do right broadcast (#31875)

上级 123949eb
......@@ -124,6 +124,9 @@ class SetValueMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::vector<int64_t>>(
"steps", "(list<int64_t>) Stride step from the start to the end.")
.SetDefault({});
AddAttr<std::vector<int64_t>>("decrease_axes",
"(list<int>) The axes to decrease.")
.SetDefault({});
AddAttr<std::vector<int>>("bool_values", "Store the bool values.")
.SetDefault({});
......@@ -185,4 +188,10 @@ Upgrade set_value, add 3 inputs [StartsTensorList, EndsTensorList, StepsTensorLi
"Ending indices of corresponding axis in `axes`.",
std::vector<int64_t>{})
.NewAttr("steps", "Stride step from the start to the end.",
std::vector<int64_t>{}));
std::vector<int64_t>{}))
.AddCheckpoint(
R"ROC(
Upgrade set_value, add 1 attribute [decrease_axes].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"decrease_axes", "The axes to decrease.", std::vector<int64_t>{}));
......@@ -106,10 +106,10 @@ inline void CheckAndUpdateSlice(const framework::DDim in_dims,
}
inline framework::DDim GetSliceDims(const framework::DDim in_dims,
const std::vector<int64_t> axes,
const std::vector<int64_t> starts,
const std::vector<int64_t> ends,
const std::vector<int64_t> steps) {
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
const std::vector<int64_t>& steps) {
framework::DDim slice_dims(in_dims);
for (size_t i = 0; i < axes.size(); ++i) {
......@@ -127,6 +127,38 @@ inline framework::DDim GetSliceDims(const framework::DDim in_dims,
return slice_dims;
}
// Returns the dims obtained by removing every axis listed in `decrease_axes`
// from `slice_dims`. Each decreased axis must have extent 1 (it was produced
// by integer indexing, e.g. x[:, 0] -> slice dims [3, 1] -> decreased [3]).
// If all axes are decreased, returns [1] because Paddle does not support
// 0-rank Tensors.
inline framework::DDim GetDecreasedDims(
    const framework::DDim slice_dims,
    const std::vector<int64_t>& decrease_axes) {
  framework::DDim decreased_dims(slice_dims);
  if (decrease_axes.size() > 0) {
    for (int64_t axis : decrease_axes) {
      // NOTE(review): assumes 0 <= axis < slice_dims.size(); callers pass
      // axes collected from indexing, so out-of-range is not expected here.
      PADDLE_ENFORCE_EQ(
          decreased_dims[axis], 1,
          platform::errors::InvalidArgument(
              "The value of decreased dimension should be 1, but received %d.",
              decreased_dims[axis]));
      // Mark the axis with 0 so the filtering loop below drops it.
      decreased_dims[axis] = 0;
    }

    std::vector<int64_t> new_shape;
    new_shape.reserve(decreased_dims.size());
    for (int i = 0; i < decreased_dims.size(); ++i) {
      if (decreased_dims[i] != 0) {
        new_shape.push_back(decreased_dims[i]);
      }
    }

    // NOTE(liym27): Paddle does not support that the rank of Tensor is 0, and
    // uses [1] instead.
    if (new_shape.size() == 0) {
      new_shape.push_back(1);
    }
    decreased_dims = framework::make_ddim(new_shape);
  }
  return decreased_dims;
}
template <typename DeviceContext, typename T>
class SetValueKernel : public framework::OpKernel<T> {
public:
......@@ -179,6 +211,7 @@ class SetValueKernel : public framework::OpKernel<T> {
auto ends = ctx.Attr<std::vector<int64_t>>("ends");
auto steps = ctx.Attr<std::vector<int64_t>>("steps");
auto shape = ctx.Attr<std::vector<int64_t>>("shape");
auto decrease_axes = ctx.Attr<std::vector<int64_t>>("decrease_axes");
auto dtype = in->type();
if (!starts_tensor_list.empty()) {
......@@ -194,6 +227,7 @@ class SetValueKernel : public framework::OpKernel<T> {
auto in_dims = in->dims();
CheckAndUpdateSlice(in_dims, axes, &starts, &ends, &steps);
auto slice_dims = GetSliceDims(in_dims, axes, starts, ends, steps);
auto decrease_slice_dims = GetDecreasedDims(slice_dims, decrease_axes);
auto place = ctx.GetPlace();
auto& eigen_place =
......@@ -212,13 +246,13 @@ class SetValueKernel : public framework::OpKernel<T> {
// set_value is what we want.
TensorCopy(*in, place, out);
Tensor slice_t(dtype), pad_t(dtype);
slice_t.mutable_data<T>(slice_dims, place);
pad_t.mutable_data<T>(in_dims, place);
Tensor slice_tensor(dtype), pad_tensor(dtype);
slice_tensor.mutable_data<T>(slice_dims, place);
pad_tensor.mutable_data<T>(in_dims, place);
auto pad_e = framework::EigenTensor<T, D>::From(pad_t, in_dims);
auto pad_e = framework::EigenTensor<T, D>::From(pad_tensor, in_dims);
auto out_e = framework::EigenTensor<T, D>::From(*out);
auto slice_e = framework::EigenTensor<T, D>::From(slice_t, slice_dims);
auto slice_e = framework::EigenTensor<T, D>::From(slice_tensor, slice_dims);
// Step 1: Set the value of out at `_index` to zero
slice_e.device(eigen_place) = slice_e.constant(T(0));
......@@ -244,11 +278,26 @@ class SetValueKernel : public framework::OpKernel<T> {
// Step 2: Set a tensor with the same shape as out tensor. And its data at
// '_index' is the same as value_tensor, and data out of '_index' to zero
// - Step 2.1 Set slice tensor with value
// NOTE(liym27): [ Why resize slice_tensor here? ]
// A: When do broadcasting on slice_tensor and value_tensor, the shape of
// slice_tensor should be decreased dims.
// e.g.
// x[:,0] = value_tensor
// x's shape = [3, 4], value_tensor's shape = [3]
// We get slice_dims = [3, 1], decrease_slice_dims = [3]
// If do broadcasting on Tensor with shape [3, 1] and [3], the result's
// shape is [3, 3], which cross the border;
// If do broadcasting on Tensor with shape [3] and [3], the result's shape
// is [3], which is right.
slice_tensor.Resize(decrease_slice_dims);
if (value_tensor != nullptr) {
// ElementwiseComputeEx can do broadcasting
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, &slice_t, value_tensor, -1, SubFunctor<T>(), &slice_t);
ctx, &slice_tensor, value_tensor, -1, SubFunctor<T>(), &slice_tensor);
} else {
Tensor value_t(dtype);
auto value_dims = framework::make_ddim(shape);
......@@ -257,8 +306,9 @@ class SetValueKernel : public framework::OpKernel<T> {
CopyVecotorToTensor<T>(value_name.c_str(), &value_t, ctx);
value_t.Resize(value_dims);
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, &slice_t, &value_t, -1, SubFunctor<T>(), &slice_t);
ctx, &slice_tensor, &value_t, -1, SubFunctor<T>(), &slice_tensor);
}
slice_tensor.Resize(slice_dims);
// - Step 2.2 Pad slice tensor with 0
pad_e.device(eigen_place) = pad_e.constant(T(0));
......
......@@ -1863,6 +1863,7 @@ class Variable(object):
if not isinstance(item, tuple):
item = [item]
decrease_axes = []
axes = []
starts = []
ends = []
......@@ -1933,15 +1934,23 @@ class Variable(object):
if end is None:
end = max_integer if step > 0 else (0 - max_integer)
else:
decrease_axes.append(dim)
start = slice_item
end = slice_item + 1 if slice_item != -1 else max_integer
step = 1
axes.append(dim)
starts.append(start)
ends.append(end)
steps.append(step)
attrs = {'axes': axes, 'starts': starts, 'ends': ends, 'steps': steps}
attrs = {
'axes': axes,
'starts': starts,
'ends': ends,
'steps': steps,
'decrease_axes': decrease_axes
}
from .layers import utils
if utils._contain_var(starts):
......
......@@ -671,6 +671,20 @@ class TestSetValueValueShape4(TestSetValueApi):
self.data[0] = self.value
class TestSetValueValueShape5(TestSetValueApi):
    # Regression test for set_value with decreased axes:
    # x[:, 0] = value yields slice dims [3, 1], which must be decreased to
    # [3] before broadcasting against a value of shape [3].
    # NOTE(review): self.dtype, self.shape and self.data are presumably set
    # up by the TestSetValueApi/TestSetValueBase machinery — not visible here.
    def set_value(self):
        # 1-D value whose shape [3] matches the decreased slice shape.
        self.value = np.array([3, 3, 3]).astype(self.dtype)

    def set_shape(self):
        self.shape = [3, 4]

    def _call_setitem(self, x):
        # Integer index on axis 1 triggers the decrease_axes path.
        x[:, 0] = paddle.assign(self.value)  # x is Paddle.Tensor

    def _get_answer(self):
        # Reference result via numpy basic indexing on the host copy.
        self.data[:, 0] = self.value
# 4. Test error
class TestError(TestSetValueBase):
def _value_type_error(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册