diff --git a/paddle/phi/kernels/funcs/index_put_utils.h b/paddle/phi/kernels/funcs/index_put_utils.h index 51e918c852347cdc5a42994f7a06a0ddbc24e305..c135cb82e2ec346d0d8005ede1d1372769a5efd8 100644 --- a/paddle/phi/kernels/funcs/index_put_utils.h +++ b/paddle/phi/kernels/funcs/index_put_utils.h @@ -74,57 +74,50 @@ std::vector DealWithBoolIndices( std::vector* tmp_indices_v) { std::vector res(indices_v.begin(), indices_v.end()); bool contains_bool_tensor = false; + for (size_t i = 0; i < indices_v.size(); ++i) { if (indices_v[i]->dtype() == phi::DataType::BOOL) { contains_bool_tensor = true; + int rank = indices_v[i]->dims().size(); + PADDLE_ENFORCE_GE( + rank, + 1UL, + phi::errors::InvalidArgument("the only bool tensor in indices should " + "have number of dimension at least 1")); + phi::DenseTensor nonzero_indices(phi::DataType::INT64); + nonzero_indices.Resize(phi::make_ddim({-1, rank})); + NonZeroKernel(dev_ctx, *indices_v[i], &nonzero_indices); + + std::vector integer_indices(rank, nullptr); + const int tmp_ix = tmp_indices_v->size(); + for (int i = 0; i < rank; ++i) { + tmp_indices_v->emplace_back( + DenseTensor(phi::DataType::INT64) + .Resize(phi::make_ddim({nonzero_indices.dims()[0]}))); + } + for (int i = 0; i < rank; ++i) { + integer_indices[i] = &((*tmp_indices_v)[i + tmp_ix]); + } + SplitWithNumKernel( + dev_ctx, nonzero_indices, rank, 1, integer_indices); + } else if ((indices_v[i]->dtype() == phi::DataType::INT64) || (indices_v[i]->dtype() == phi::DataType::INT32)) { - PADDLE_ENFORCE_EQ( - contains_bool_tensor, - false, - phi::errors::InvalidArgument( - "indices contains bool tensor and int32/int64 tensor at the same " - "time")); + tmp_indices_v->emplace_back(*indices_v[i]); } else { PADDLE_THROW(phi::errors::InvalidArgument( "data type of tensor in indices must be int32, int64 or bool")); } } - if (contains_bool_tensor) { - if (indices_v.size() != 1) { - PADDLE_THROW(phi::errors::InvalidArgument( - "the size of indices must be 1 when it containts bool tensor")); - } - int rank = indices_v[0]->dims().size(); - PADDLE_ENFORCE_GE( - rank, - 1UL, - phi::errors::InvalidArgument("the only bool tensor in indices should " - "have number of dimension at least 1")); - phi::DenseTensor nonzero_indices(phi::DataType::INT64); - nonzero_indices.Resize(phi::make_ddim({-1, rank})); - NonZeroKernel(dev_ctx, *indices_v[0], &nonzero_indices); - - std::vector integer_indices(rank, nullptr); - for (int i = 0; i < rank; ++i) { - tmp_indices_v->emplace_back( - DenseTensor(phi::DataType::INT64) - .Resize(phi::make_ddim({nonzero_indices.dims()[0]}))); - } - for (int i = 0; i < rank; ++i) { - integer_indices[i] = &((*tmp_indices_v)[i]); - } - SplitWithNumKernel( - dev_ctx, nonzero_indices, rank, 1, integer_indices); - - std::vector res_tmp(integer_indices.size(), + std::vector res_tmp(tmp_indices_v->size(), nullptr); - for (int i = 0; i < rank; ++i) { + for (size_t i = 0; i < res_tmp.size(); ++i) { res_tmp[i] = &((*tmp_indices_v)[i]); } res.swap(res_tmp); } + return res; } diff --git a/test/legacy_test/test_index_put_op.py b/test/legacy_test/test_index_put_op.py index 5f0257f25535c7f4c82d018deefb98971405e56b..c4bf5d6f0fd401f608eefb5495bee4e4f776f7a6 100644 --- a/test/legacy_test/test_index_put_op.py +++ b/test/legacy_test/test_index_put_op.py @@ -77,13 +77,26 @@ def gen_indices_np(x_shape, indices_shapes, index_type): class TestIndexPutAPIBase(unittest.TestCase): def setUp(self): + self.mixed_indices = False self.init_dtype_type() self.setPlace() self.x_np = np.random.random(self.x_shape).astype(self.dtype_np) self.value_np = np.random.random(self.value_shape).astype(self.dtype_np) - self.indices_np = gen_indices_np( - self.x_shape, self.indices_shapes, self.index_type_np - ) + + if self.mixed_indices: + tmp_indices_np1 = gen_indices_np( + self.x_shape, self.indices_shapes, self.index_type_np + ) + tmp_indices_np2 = gen_indices_np( + self.x_shape, self.indices_shapes1, self.index_type_np1 + ) + self.indices_np = tuple( + list(tmp_indices_np1) + list(tmp_indices_np2) + ) + else: + self.indices_np = gen_indices_np( + self.x_shape, self.indices_shapes, self.index_type_np + ) def init_dtype_type(self): self.dtype_np = np.float64 @@ -109,8 +122,7 @@ class TestIndexPutAPIBase(unittest.TestCase): self.x_pd = paddle.to_tensor(self.x_np, dtype=self.dtype_pd) self.value_pd = paddle.to_tensor(self.value_np, dtype=self.dtype_pd) self.indices_pd = [ - paddle.to_tensor(indice, dtype=self.index_type_pd) - for indice in self.indices_np + paddle.to_tensor(indice) for indice in self.indices_np ] self.indices_pd = tuple(self.indices_pd) ref_res = compute_index_put_ref( @@ -128,16 +140,37 @@ class TestIndexPutAPIBase(unittest.TestCase): x = paddle.static.data( name="x", shape=self.x_shape, dtype=self.dtype_pd ) - indices = tuple( - [ - paddle.static.data( - name="indice" + str(i), - shape=self.indices_shapes[i], - dtype=self.index_type_pd, - ) - for i in range(len(self.indices_shapes)) - ] - ) + if self.mixed_indices: + indices = tuple( + [ + paddle.static.data( + name="indice" + str(i), + shape=self.indices_shapes[i], + dtype=self.index_type_pd, + ) + for i in range(len(self.indices_shapes)) + ] + + [ + paddle.static.data( + name="indice" + + str(i + len(self.indices_shapes)), + shape=self.indices_shapes1[i], + dtype=self.index_type_pd1, + ) + for i in range(len(self.indices_shapes1)) + ] + ) + else: + indices = tuple( + [ + paddle.static.data( + name="indice" + str(i), + shape=self.indices_shapes[i], + dtype=self.index_type_pd, + ) + for i in range(len(self.indices_shapes)) + ] + ) value = paddle.static.data( name="value", shape=self.value_shape, dtype=self.dtype_pd ) @@ -822,5 +855,39 @@ class TestIndexPutAPIBackward(unittest.TestCase): ) +class TestIndexPutAPIMixedIndices(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int32 + self.x_shape = (110, 42, 32, 56) + self.indices_shapes = ((16, 16), (16, 16)) + self.value_shape = (16, 16, 56) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int32 + self.accumulate = False + + self.mixed_indices = True + self.index_type_np1 = np.bool_ + self.indices_shapes1 = [(32,)] + self.index_type_pd1 = paddle.bool + + +class TestIndexPutAPIMixedIndices1(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int32 + self.x_shape = (110, 42, 32, 56) + self.indices_shapes = ((16, 16), (16, 16)) + self.value_shape = (16, 16, 56) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int32 + self.accumulate = True + + self.mixed_indices = True + self.index_type_np1 = np.bool_ + self.indices_shapes1 = [(32,)] + self.index_type_pd1 = paddle.bool + + if __name__ == '__main__': unittest.main()