Unverified commit 525c32e3, authored by liym27, committed by GitHub

Fix bug of set_value op: Decrease axes to do right broadcast (#31875)

Parent 123949eb
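The bug this commit fixes shows up when an axis is indexed with an integer, so the slice keeps a size-1 dimension that the value tensor does not have. A minimal reproduction sketch (assuming a Paddle 2.x build where `Tensor.__setitem__` is supported in dygraph mode; before this patch, the assignment below would broadcast shapes [3, 1] and [3] into [3, 3] instead of [3]):

```python
import numpy as np
import paddle

x = paddle.to_tensor(np.zeros((3, 4), dtype="float32"))
value = paddle.to_tensor(np.array([3, 3, 3], dtype="float32"))

# x[:, 0] slices a region of shape [3, 1]; the integer index 0 decreases
# axis 1, so the assignment should broadcast [3] against [3], not [3, 1].
x[:, 0] = value
print(x.numpy()[:, 0])  # expected: [3. 3. 3.]
```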
paddle/fluid/operators/set_value_op.cc
@@ -124,6 +124,9 @@ class SetValueMaker : public framework::OpProtoAndCheckerMaker {
    AddAttr<std::vector<int64_t>>(
        "steps", "(list<int64_t>) Stride step from the start to the end.")
        .SetDefault({});
AddAttr<std::vector<int64_t>>("decrease_axes",
"(list<int>) The axes to decrease.")
.SetDefault({});
AddAttr<std::vector<int>>("bool_values", "Store the bool values.") AddAttr<std::vector<int>>("bool_values", "Store the bool values.")
.SetDefault({}); .SetDefault({});
@@ -185,4 +188,10 @@ Upgrade set_value, add 3 inputs [StartsTensorList, EndsTensorList, StepsTensorList]
"Ending indices of corresponding axis in `axes`.", "Ending indices of corresponding axis in `axes`.",
std::vector<int64_t>{}) std::vector<int64_t>{})
.NewAttr("steps", "Stride step from the start to the end.", .NewAttr("steps", "Stride step from the start to the end.",
std::vector<int64_t>{})); std::vector<int64_t>{}))
        .AddCheckpoint(
            R"ROC(
Upgrade set_value, add 1 attribute [decrease_axes].
            )ROC",
            paddle::framework::compatible::OpVersionDesc().NewAttr(
                "decrease_axes", "The axes to decrease.",
                std::vector<int64_t>{}));
paddle/fluid/operators/set_value_op.h
@@ -106,10 +106,10 @@ inline void CheckAndUpdateSlice(const framework::DDim in_dims,
}
inline framework::DDim GetSliceDims(const framework::DDim in_dims,
                                    const std::vector<int64_t>& axes,
                                    const std::vector<int64_t>& starts,
                                    const std::vector<int64_t>& ends,
                                    const std::vector<int64_t>& steps) {
  framework::DDim slice_dims(in_dims);

  for (size_t i = 0; i < axes.size(); ++i) {
@@ -127,6 +127,38 @@ inline framework::DDim GetSliceDims(const framework::DDim in_dims,
  return slice_dims;
}
inline framework::DDim GetDecreasedDims(
    const framework::DDim slice_dims,
    const std::vector<int64_t>& decrease_axes) {
  // Get dims after decreasing axes.
  framework::DDim decreased_dims(slice_dims);
  if (decrease_axes.size() > 0) {
    for (size_t i = 0; i < decrease_axes.size(); ++i) {
      int64_t axis = decrease_axes[i];
      PADDLE_ENFORCE_EQ(
          decreased_dims[axis], 1,
          platform::errors::InvalidArgument("decrease dim should be 1"));
      decreased_dims[axis] = 0;
    }

    std::vector<int64_t> new_shape;
    for (int i = 0; i < decreased_dims.size(); ++i) {
      if (decreased_dims[i] != 0) {
        new_shape.push_back(decreased_dims[i]);
      }
    }

    // NOTE(liym27): Paddle does not support that the rank of Tensor is 0, and
    // uses [1] instead.
    if (new_shape.size() == 0) {
      new_shape.push_back(1);
    }

    decreased_dims = framework::make_ddim(new_shape);
  }
  return decreased_dims;
}
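The helper above can be mirrored in a few lines of plain Python (a rough sketch of the same logic with a hypothetical name `get_decreased_dims`; shapes are plain lists rather than `framework::DDim`):

```python
def get_decreased_dims(slice_dims, decrease_axes):
    """Drop the size-1 axes produced by integer indexing, keeping rank >= 1."""
    decreased = list(slice_dims)
    for axis in decrease_axes:
        assert decreased[axis] == 1, "decrease dim should be 1"
        decreased[axis] = 0  # mark the axis for removal
    new_shape = [d for d in decreased if d != 0]
    # Paddle does not support rank-0 Tensors, so [1] stands in for a scalar.
    return new_shape or [1]

assert get_decreased_dims([3, 1], [1]) == [3]     # x[:, 0] with x of shape [3, 4]
assert get_decreased_dims([1, 1], [0, 1]) == [1]  # all axes decreased
```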
template <typename DeviceContext, typename T>
class SetValueKernel : public framework::OpKernel<T> {
 public:
@@ -179,6 +211,7 @@ class SetValueKernel : public framework::OpKernel<T> {
    auto ends = ctx.Attr<std::vector<int64_t>>("ends");
    auto steps = ctx.Attr<std::vector<int64_t>>("steps");
    auto shape = ctx.Attr<std::vector<int64_t>>("shape");
    auto decrease_axes = ctx.Attr<std::vector<int64_t>>("decrease_axes");
    auto dtype = in->type();

    if (!starts_tensor_list.empty()) {
@@ -194,6 +227,7 @@ class SetValueKernel : public framework::OpKernel<T> {
    auto in_dims = in->dims();
    CheckAndUpdateSlice(in_dims, axes, &starts, &ends, &steps);
    auto slice_dims = GetSliceDims(in_dims, axes, starts, ends, steps);
    auto decrease_slice_dims = GetDecreasedDims(slice_dims, decrease_axes);
    auto place = ctx.GetPlace();
    auto& eigen_place =
@@ -212,13 +246,13 @@ class SetValueKernel : public framework::OpKernel<T> {
    // set_value is what we want.
    TensorCopy(*in, place, out);
    Tensor slice_tensor(dtype), pad_tensor(dtype);
    slice_tensor.mutable_data<T>(slice_dims, place);
    pad_tensor.mutable_data<T>(in_dims, place);

    auto pad_e = framework::EigenTensor<T, D>::From(pad_tensor, in_dims);
    auto out_e = framework::EigenTensor<T, D>::From(*out);
    auto slice_e = framework::EigenTensor<T, D>::From(slice_tensor, slice_dims);
    // Step 1: Set the value of out at `_index` to zero
    slice_e.device(eigen_place) = slice_e.constant(T(0));
@@ -244,11 +278,26 @@ class SetValueKernel : public framework::OpKernel<T> {
    // Step 2: Set a tensor with the same shape as out tensor. Its data at
    // '_index' is the same as value_tensor, and zero outside of '_index'.
    // - Step 2.1 Set slice tensor with value
    // NOTE(liym27): [ Why resize slice_tensor here? ]
    // A: When broadcasting between slice_tensor and value_tensor, the shape of
    // slice_tensor should be the decreased dims.
    // e.g.
    //  x[:,0] = value_tensor
    //  x's shape = [3, 4], value_tensor's shape = [3]
    //  We get slice_dims = [3, 1], decrease_slice_dims = [3]
    //  Broadcasting tensors of shapes [3, 1] and [3] gives shape [3, 3],
    //  which is out of bounds;
    //  Broadcasting tensors of shapes [3] and [3] gives shape [3], which is
    //  right.
    slice_tensor.Resize(decrease_slice_dims);
    if (value_tensor != nullptr) {
      // ElementwiseComputeEx can do broadcasting
      ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
          ctx, &slice_tensor, value_tensor, -1, SubFunctor<T>(), &slice_tensor);
    } else {
      Tensor value_t(dtype);
      auto value_dims = framework::make_ddim(shape);
@@ -257,8 +306,9 @@ class SetValueKernel : public framework::OpKernel<T> {
      CopyVecotorToTensor<T>(value_name.c_str(), &value_t, ctx);
      value_t.Resize(value_dims);
      ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
          ctx, &slice_tensor, &value_t, -1, SubFunctor<T>(), &slice_tensor);
    }
    slice_tensor.Resize(slice_dims);
    // - Step 2.2 Pad slice tensor with 0
    pad_e.device(eigen_place) = pad_e.constant(T(0));
...
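The shape arithmetic in the NOTE is easy to check with NumPy, which follows the same broadcasting rules (a standalone sketch, independent of the Paddle kernel):

```python
import numpy as np

slice_full = np.zeros((3, 1))  # slice_dims for x[:, 0] with x of shape [3, 4]
value = np.zeros(3)            # value_tensor of shape [3]

# Broadcasting [3, 1] against [3] yields [3, 3] -- larger than the slice.
print((slice_full - value).shape)             # (3, 3)

# After decreasing the indexed axis, [3] against [3] yields [3] as intended.
print((slice_full.reshape(3) - value).shape)  # (3,)
```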
python/paddle/fluid/framework.py
@@ -1863,6 +1863,7 @@ class Variable(object):
        if not isinstance(item, tuple):
            item = [item]
        decrease_axes = []
        axes = []
        starts = []
        ends = []
@@ -1933,15 +1934,23 @@ class Variable(object):
                    if end is None:
                        end = max_integer if step > 0 else (0 - max_integer)
                else:
                    decrease_axes.append(dim)
                    start = slice_item
                    end = slice_item + 1 if slice_item != -1 else max_integer
                    step = 1
                axes.append(dim)
                starts.append(start)
                ends.append(end)
                steps.append(step)
        attrs = {
            'axes': axes,
            'starts': starts,
            'ends': ends,
            'steps': steps,
            'decrease_axes': decrease_axes
        }
        from .layers import utils
        if utils._contain_var(starts):
...
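For an index such as x[:, 0] on a Variable of shape [3, 4], the attrs assembled above come out roughly as follows (an illustration inferred from the code, not captured from a run; a full ':' slice typically contributes nothing, so only the integer-indexed axis appears):

```python
attrs = {
    'axes': [1],           # only axis 1 is restricted by the index
    'starts': [0],
    'ends': [1],
    'steps': [1],
    'decrease_axes': [1],  # axis 1 was indexed by an integer, so its size-1
                           # dimension is dropped before broadcasting
}
```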
python/paddle/fluid/tests/unittests/test_set_value_op.py
@@ -671,6 +671,20 @@ class TestSetValueValueShape4(TestSetValueApi):
        self.data[0] = self.value
class TestSetValueValueShape5(TestSetValueApi):
    def set_value(self):
        self.value = np.array([3, 3, 3]).astype(self.dtype)

    def set_shape(self):
        self.shape = [3, 4]

    def _call_setitem(self, x):
        x[:, 0] = paddle.assign(self.value)  # x is Paddle.Tensor

    def _get_answer(self):
        self.data[:, 0] = self.value
# 4. Test error
class TestError(TestSetValueBase):
    def _value_type_error(self):
...