未验证 提交 8fd4ef91 编写于 作者: 傅剑寒 提交者: GitHub

add mixed bool and int index support for index_put (#54195)

上级 9a8e9417
......@@ -74,57 +74,50 @@ std::vector<const phi::DenseTensor*> DealWithBoolIndices(
std::vector<phi::DenseTensor>* tmp_indices_v) {
std::vector<const phi::DenseTensor*> res(indices_v.begin(), indices_v.end());
bool contains_bool_tensor = false;
for (size_t i = 0; i < indices_v.size(); ++i) {
if (indices_v[i]->dtype() == phi::DataType::BOOL) {
contains_bool_tensor = true;
int rank = indices_v[i]->dims().size();
PADDLE_ENFORCE_GE(
rank,
1UL,
phi::errors::InvalidArgument("the only bool tensor in indices should "
"have number of dimension at least 1"));
phi::DenseTensor nonzero_indices(phi::DataType::INT64);
nonzero_indices.Resize(phi::make_ddim({-1, rank}));
NonZeroKernel<bool, Context>(dev_ctx, *indices_v[i], &nonzero_indices);
std::vector<phi::DenseTensor*> integer_indices(rank, nullptr);
const int tmp_ix = tmp_indices_v->size();
for (int i = 0; i < rank; ++i) {
tmp_indices_v->emplace_back(
DenseTensor(phi::DataType::INT64)
.Resize(phi::make_ddim({nonzero_indices.dims()[0]})));
}
for (int i = 0; i < rank; ++i) {
integer_indices[i] = &((*tmp_indices_v)[i + tmp_ix]);
}
SplitWithNumKernel<int64_t, Context>(
dev_ctx, nonzero_indices, rank, 1, integer_indices);
} else if ((indices_v[i]->dtype() == phi::DataType::INT64) ||
(indices_v[i]->dtype() == phi::DataType::INT32)) {
PADDLE_ENFORCE_EQ(
contains_bool_tensor,
false,
phi::errors::InvalidArgument(
"indices contains bool tensor and int32/int64 tensor at the same "
"time"));
tmp_indices_v->emplace_back(*indices_v[i]);
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"data type of tensor in indices must be int32, int64 or bool"));
}
}
if (contains_bool_tensor) {
if (indices_v.size() != 1) {
PADDLE_THROW(phi::errors::InvalidArgument(
"the size of indices must be 1 when it containts bool tensor"));
}
int rank = indices_v[0]->dims().size();
PADDLE_ENFORCE_GE(
rank,
1UL,
phi::errors::InvalidArgument("the only bool tensor in indices should "
"have number of dimension at least 1"));
phi::DenseTensor nonzero_indices(phi::DataType::INT64);
nonzero_indices.Resize(phi::make_ddim({-1, rank}));
NonZeroKernel<bool, Context>(dev_ctx, *indices_v[0], &nonzero_indices);
std::vector<phi::DenseTensor*> integer_indices(rank, nullptr);
for (int i = 0; i < rank; ++i) {
tmp_indices_v->emplace_back(
DenseTensor(phi::DataType::INT64)
.Resize(phi::make_ddim({nonzero_indices.dims()[0]})));
}
for (int i = 0; i < rank; ++i) {
integer_indices[i] = &((*tmp_indices_v)[i]);
}
SplitWithNumKernel<int64_t, Context>(
dev_ctx, nonzero_indices, rank, 1, integer_indices);
std::vector<const phi::DenseTensor*> res_tmp(integer_indices.size(),
std::vector<const phi::DenseTensor*> res_tmp(tmp_indices_v->size(),
nullptr);
for (int i = 0; i < rank; ++i) {
for (size_t i = 0; i < res_tmp.size(); ++i) {
res_tmp[i] = &((*tmp_indices_v)[i]);
}
res.swap(res_tmp);
}
return res;
}
......
......@@ -77,13 +77,26 @@ def gen_indices_np(x_shape, indices_shapes, index_type):
class TestIndexPutAPIBase(unittest.TestCase):
def setUp(self):
self.mixed_indices = False
self.init_dtype_type()
self.setPlace()
self.x_np = np.random.random(self.x_shape).astype(self.dtype_np)
self.value_np = np.random.random(self.value_shape).astype(self.dtype_np)
self.indices_np = gen_indices_np(
self.x_shape, self.indices_shapes, self.index_type_np
)
if self.mixed_indices:
tmp_indices_np1 = gen_indices_np(
self.x_shape, self.indices_shapes, self.index_type_np
)
tmp_indices_np2 = gen_indices_np(
self.x_shape, self.indices_shapes1, self.index_type_np1
)
self.indices_np = tuple(
list(tmp_indices_np1) + list(tmp_indices_np2)
)
else:
self.indices_np = gen_indices_np(
self.x_shape, self.indices_shapes, self.index_type_np
)
def init_dtype_type(self):
self.dtype_np = np.float64
......@@ -109,8 +122,7 @@ class TestIndexPutAPIBase(unittest.TestCase):
self.x_pd = paddle.to_tensor(self.x_np, dtype=self.dtype_pd)
self.value_pd = paddle.to_tensor(self.value_np, dtype=self.dtype_pd)
self.indices_pd = [
paddle.to_tensor(indice, dtype=self.index_type_pd)
for indice in self.indices_np
paddle.to_tensor(indice) for indice in self.indices_np
]
self.indices_pd = tuple(self.indices_pd)
ref_res = compute_index_put_ref(
......@@ -128,16 +140,37 @@ class TestIndexPutAPIBase(unittest.TestCase):
x = paddle.static.data(
name="x", shape=self.x_shape, dtype=self.dtype_pd
)
indices = tuple(
[
paddle.static.data(
name="indice" + str(i),
shape=self.indices_shapes[i],
dtype=self.index_type_pd,
)
for i in range(len(self.indices_shapes))
]
)
if self.mixed_indices:
indices = tuple(
[
paddle.static.data(
name="indice" + str(i),
shape=self.indices_shapes[i],
dtype=self.index_type_pd,
)
for i in range(len(self.indices_shapes))
]
+ [
paddle.static.data(
name="indice"
+ str(i + len(self.indices_shapes)),
shape=self.indices_shapes1[i],
dtype=self.index_type_pd1,
)
for i in range(len(self.indices_shapes1))
]
)
else:
indices = tuple(
[
paddle.static.data(
name="indice" + str(i),
shape=self.indices_shapes[i],
dtype=self.index_type_pd,
)
for i in range(len(self.indices_shapes))
]
)
value = paddle.static.data(
name="value", shape=self.value_shape, dtype=self.dtype_pd
)
......@@ -822,5 +855,39 @@ class TestIndexPutAPIBackward(unittest.TestCase):
)
class TestIndexPutAPIMixedIndices(TestIndexPutAPIBase):
def init_dtype_type(self):
self.dtype_np = np.float64
self.index_type_np = np.int32
self.x_shape = (110, 42, 32, 56)
self.indices_shapes = ((16, 16), (16, 16))
self.value_shape = (16, 16, 56)
self.dtype_pd = paddle.float64
self.index_type_pd = paddle.int32
self.accumulate = False
self.mixed_indices = True
self.index_type_np1 = np.bool_
self.indices_shapes1 = [(32,)]
self.index_type_pd1 = paddle.bool
class TestIndexPutAPIMixedIndices1(TestIndexPutAPIBase):
def init_dtype_type(self):
self.dtype_np = np.float64
self.index_type_np = np.int32
self.x_shape = (110, 42, 32, 56)
self.indices_shapes = ((16, 16), (16, 16))
self.value_shape = (16, 16, 56)
self.dtype_pd = paddle.float64
self.index_type_pd = paddle.int32
self.accumulate = True
self.mixed_indices = True
self.index_type_np1 = np.bool_
self.indices_shapes1 = [(32,)]
self.index_type_pd1 = paddle.bool
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册