未验证 提交 8fd4ef91 编写于 作者: 傅剑寒 提交者: GitHub

add mixed bool and int index support for index_put (#54195)

上级 9a8e9417
...@@ -74,57 +74,50 @@ std::vector<const phi::DenseTensor*> DealWithBoolIndices( ...@@ -74,57 +74,50 @@ std::vector<const phi::DenseTensor*> DealWithBoolIndices(
std::vector<phi::DenseTensor>* tmp_indices_v) { std::vector<phi::DenseTensor>* tmp_indices_v) {
std::vector<const phi::DenseTensor*> res(indices_v.begin(), indices_v.end()); std::vector<const phi::DenseTensor*> res(indices_v.begin(), indices_v.end());
bool contains_bool_tensor = false; bool contains_bool_tensor = false;
for (size_t i = 0; i < indices_v.size(); ++i) { for (size_t i = 0; i < indices_v.size(); ++i) {
if (indices_v[i]->dtype() == phi::DataType::BOOL) { if (indices_v[i]->dtype() == phi::DataType::BOOL) {
contains_bool_tensor = true; contains_bool_tensor = true;
int rank = indices_v[i]->dims().size();
PADDLE_ENFORCE_GE(
rank,
1UL,
phi::errors::InvalidArgument("the only bool tensor in indices should "
"have number of dimension at least 1"));
phi::DenseTensor nonzero_indices(phi::DataType::INT64);
nonzero_indices.Resize(phi::make_ddim({-1, rank}));
NonZeroKernel<bool, Context>(dev_ctx, *indices_v[i], &nonzero_indices);
std::vector<phi::DenseTensor*> integer_indices(rank, nullptr);
const int tmp_ix = tmp_indices_v->size();
for (int i = 0; i < rank; ++i) {
tmp_indices_v->emplace_back(
DenseTensor(phi::DataType::INT64)
.Resize(phi::make_ddim({nonzero_indices.dims()[0]})));
}
for (int i = 0; i < rank; ++i) {
integer_indices[i] = &((*tmp_indices_v)[i + tmp_ix]);
}
SplitWithNumKernel<int64_t, Context>(
dev_ctx, nonzero_indices, rank, 1, integer_indices);
} else if ((indices_v[i]->dtype() == phi::DataType::INT64) || } else if ((indices_v[i]->dtype() == phi::DataType::INT64) ||
(indices_v[i]->dtype() == phi::DataType::INT32)) { (indices_v[i]->dtype() == phi::DataType::INT32)) {
PADDLE_ENFORCE_EQ( tmp_indices_v->emplace_back(*indices_v[i]);
contains_bool_tensor,
false,
phi::errors::InvalidArgument(
"indices contains bool tensor and int32/int64 tensor at the same "
"time"));
} else { } else {
PADDLE_THROW(phi::errors::InvalidArgument( PADDLE_THROW(phi::errors::InvalidArgument(
"data type of tensor in indices must be int32, int64 or bool")); "data type of tensor in indices must be int32, int64 or bool"));
} }
} }
if (contains_bool_tensor) { if (contains_bool_tensor) {
if (indices_v.size() != 1) { std::vector<const phi::DenseTensor*> res_tmp(tmp_indices_v->size(),
PADDLE_THROW(phi::errors::InvalidArgument(
"the size of indices must be 1 when it containts bool tensor"));
}
int rank = indices_v[0]->dims().size();
PADDLE_ENFORCE_GE(
rank,
1UL,
phi::errors::InvalidArgument("the only bool tensor in indices should "
"have number of dimension at least 1"));
phi::DenseTensor nonzero_indices(phi::DataType::INT64);
nonzero_indices.Resize(phi::make_ddim({-1, rank}));
NonZeroKernel<bool, Context>(dev_ctx, *indices_v[0], &nonzero_indices);
std::vector<phi::DenseTensor*> integer_indices(rank, nullptr);
for (int i = 0; i < rank; ++i) {
tmp_indices_v->emplace_back(
DenseTensor(phi::DataType::INT64)
.Resize(phi::make_ddim({nonzero_indices.dims()[0]})));
}
for (int i = 0; i < rank; ++i) {
integer_indices[i] = &((*tmp_indices_v)[i]);
}
SplitWithNumKernel<int64_t, Context>(
dev_ctx, nonzero_indices, rank, 1, integer_indices);
std::vector<const phi::DenseTensor*> res_tmp(integer_indices.size(),
nullptr); nullptr);
for (int i = 0; i < rank; ++i) { for (size_t i = 0; i < res_tmp.size(); ++i) {
res_tmp[i] = &((*tmp_indices_v)[i]); res_tmp[i] = &((*tmp_indices_v)[i]);
} }
res.swap(res_tmp); res.swap(res_tmp);
} }
return res; return res;
} }
......
...@@ -77,13 +77,26 @@ def gen_indices_np(x_shape, indices_shapes, index_type): ...@@ -77,13 +77,26 @@ def gen_indices_np(x_shape, indices_shapes, index_type):
class TestIndexPutAPIBase(unittest.TestCase): class TestIndexPutAPIBase(unittest.TestCase):
def setUp(self): def setUp(self):
self.mixed_indices = False
self.init_dtype_type() self.init_dtype_type()
self.setPlace() self.setPlace()
self.x_np = np.random.random(self.x_shape).astype(self.dtype_np) self.x_np = np.random.random(self.x_shape).astype(self.dtype_np)
self.value_np = np.random.random(self.value_shape).astype(self.dtype_np) self.value_np = np.random.random(self.value_shape).astype(self.dtype_np)
self.indices_np = gen_indices_np(
self.x_shape, self.indices_shapes, self.index_type_np if self.mixed_indices:
) tmp_indices_np1 = gen_indices_np(
self.x_shape, self.indices_shapes, self.index_type_np
)
tmp_indices_np2 = gen_indices_np(
self.x_shape, self.indices_shapes1, self.index_type_np1
)
self.indices_np = tuple(
list(tmp_indices_np1) + list(tmp_indices_np2)
)
else:
self.indices_np = gen_indices_np(
self.x_shape, self.indices_shapes, self.index_type_np
)
def init_dtype_type(self): def init_dtype_type(self):
self.dtype_np = np.float64 self.dtype_np = np.float64
...@@ -109,8 +122,7 @@ class TestIndexPutAPIBase(unittest.TestCase): ...@@ -109,8 +122,7 @@ class TestIndexPutAPIBase(unittest.TestCase):
self.x_pd = paddle.to_tensor(self.x_np, dtype=self.dtype_pd) self.x_pd = paddle.to_tensor(self.x_np, dtype=self.dtype_pd)
self.value_pd = paddle.to_tensor(self.value_np, dtype=self.dtype_pd) self.value_pd = paddle.to_tensor(self.value_np, dtype=self.dtype_pd)
self.indices_pd = [ self.indices_pd = [
paddle.to_tensor(indice, dtype=self.index_type_pd) paddle.to_tensor(indice) for indice in self.indices_np
for indice in self.indices_np
] ]
self.indices_pd = tuple(self.indices_pd) self.indices_pd = tuple(self.indices_pd)
ref_res = compute_index_put_ref( ref_res = compute_index_put_ref(
...@@ -128,16 +140,37 @@ class TestIndexPutAPIBase(unittest.TestCase): ...@@ -128,16 +140,37 @@ class TestIndexPutAPIBase(unittest.TestCase):
x = paddle.static.data( x = paddle.static.data(
name="x", shape=self.x_shape, dtype=self.dtype_pd name="x", shape=self.x_shape, dtype=self.dtype_pd
) )
indices = tuple( if self.mixed_indices:
[ indices = tuple(
paddle.static.data( [
name="indice" + str(i), paddle.static.data(
shape=self.indices_shapes[i], name="indice" + str(i),
dtype=self.index_type_pd, shape=self.indices_shapes[i],
) dtype=self.index_type_pd,
for i in range(len(self.indices_shapes)) )
] for i in range(len(self.indices_shapes))
) ]
+ [
paddle.static.data(
name="indice"
+ str(i + len(self.indices_shapes)),
shape=self.indices_shapes1[i],
dtype=self.index_type_pd1,
)
for i in range(len(self.indices_shapes1))
]
)
else:
indices = tuple(
[
paddle.static.data(
name="indice" + str(i),
shape=self.indices_shapes[i],
dtype=self.index_type_pd,
)
for i in range(len(self.indices_shapes))
]
)
value = paddle.static.data( value = paddle.static.data(
name="value", shape=self.value_shape, dtype=self.dtype_pd name="value", shape=self.value_shape, dtype=self.dtype_pd
) )
...@@ -822,5 +855,39 @@ class TestIndexPutAPIBackward(unittest.TestCase): ...@@ -822,5 +855,39 @@ class TestIndexPutAPIBackward(unittest.TestCase):
) )
class TestIndexPutAPIMixedIndices(TestIndexPutAPIBase):
def init_dtype_type(self):
self.dtype_np = np.float64
self.index_type_np = np.int32
self.x_shape = (110, 42, 32, 56)
self.indices_shapes = ((16, 16), (16, 16))
self.value_shape = (16, 16, 56)
self.dtype_pd = paddle.float64
self.index_type_pd = paddle.int32
self.accumulate = False
self.mixed_indices = True
self.index_type_np1 = np.bool_
self.indices_shapes1 = [(32,)]
self.index_type_pd1 = paddle.bool
class TestIndexPutAPIMixedIndices1(TestIndexPutAPIBase):
def init_dtype_type(self):
self.dtype_np = np.float64
self.index_type_np = np.int32
self.x_shape = (110, 42, 32, 56)
self.indices_shapes = ((16, 16), (16, 16))
self.value_shape = (16, 16, 56)
self.dtype_pd = paddle.float64
self.index_type_pd = paddle.int32
self.accumulate = True
self.mixed_indices = True
self.index_type_np1 = np.bool_
self.indices_shapes1 = [(32,)]
self.index_type_pd1 = paddle.bool
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册