Unverified commit c737f0ae, authored by 傅剑寒, committed by GitHub

add all false bool indices support for index_put (#55655)

Parent 7da1ffbe
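
For context, this change makes `paddle.index_put` a no-op when every boolean index entry is False, a case the kernels did not previously handle. A minimal sketch of the intended behavior (shapes are illustrative, not taken from the patch):

```python
import paddle

x = paddle.ones([3, 4])
mask = paddle.zeros([3, 4], dtype=paddle.bool)  # all-False boolean index
value = paddle.full([1], 9.0)

# Nothing is selected, so the output should simply equal x.
out = paddle.index_put(x, (mask,), value)
assert bool((out == x).all())
```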
@@ -16,6 +16,7 @@
 #include <numeric>
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/cast_kernel.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/index_put_utils.h"
 #include "paddle/phi/kernels/reduce_sum_kernel.h"
@@ -188,6 +189,19 @@ void IndexPutGradKernel(const Context& dev_ctx,
   std::vector<DenseTensor> tmp_args;
   std::vector<const phi::DenseTensor*> int_indices_v =
       funcs::DealWithBoolIndices<T, Context>(dev_ctx, indices, &tmp_args);
+  if (int_indices_v.empty()) {
+    if (x_grad) {
+      phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
+    }
+    if (value_grad) {
+      FullKernel<T, Context>(dev_ctx,
+                             phi::vectorize(value_grad->dims()),
+                             0.0f,
+                             value_grad->dtype(),
+                             value_grad);
+    }
+    return;
+  }
   auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v);
   std::vector<int64_t> res_dim_v(phi::vectorize(bd_dim));
...
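The early return above encodes the gradient semantics for an empty selection: `value` never reaches the output, so `x_grad` is a pass-through copy of `out_grad` and `value_grad` is all zeros, in both accumulate modes. A NumPy sketch of the expected gradients (hypothetical helper name, mirroring the new test added below):

```python
import numpy as np

def index_put_grad_all_false(out_grad: np.ndarray, value_shape):
    # With an all-False mask, out == x, so d(out)/d(x) is the identity,
    # and value never influences out, so its gradient is zero.
    x_grad = out_grad.copy()
    value_grad = np.zeros(value_shape, dtype=out_grad.dtype)
    return x_grad, value_grad
```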
@@ -117,6 +117,12 @@ void IndexPutKernel(const Context& dev_ctx,
   std::vector<DenseTensor> tmp_args;
   std::vector<const phi::DenseTensor*> int_indices_v =
       funcs::DealWithBoolIndices<T, Context>(dev_ctx, indices, &tmp_args);
+  if (int_indices_v.empty()) {
+    if (!out->initialized()) {
+      phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
+    }
+    return;
+  }
   auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v);
...
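In the forward kernel, the `!out->initialized()` guard separates the out-of-place call, where out must be filled with a copy of x, from the inplace path, where out already aliases x and no copy is needed. A quick check of both call forms under the new behavior (a sketch, assuming the dygraph `index_put`/`index_put_` APIs):

```python
import paddle

x = paddle.ones([4, 5])
mask = paddle.zeros([4, 5], dtype=paddle.bool)
value = paddle.full([1], 7.0)

out = paddle.index_put(x, (mask,), value)   # out-of-place: a copy of x
paddle.index_put_(x, (mask,), value)        # inplace: x left untouched
assert float(x.sum()) == 20.0 and float(out.sum()) == 20.0
```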
@@ -88,6 +88,11 @@ std::vector<const phi::DenseTensor*> DealWithBoolIndices(
       nonzero_indices.Resize(phi::make_ddim({-1, rank}));
       NonZeroKernel<bool, Context>(dev_ctx, *indices_v[i], &nonzero_indices);
+      if (nonzero_indices.numel() == 0) {
+        std::vector<const phi::DenseTensor*> empty_indices;
+        return empty_indices;
+      }
+
       std::vector<phi::DenseTensor*> integer_indices(rank, nullptr);
       const int tmp_ix = tmp_indices_v->size();
       for (int i = 0; i < rank; ++i) {
...
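The empty vector returned here acts as a sentinel: once `NonZeroKernel` finds no True entries, `DealWithBoolIndices` hands back an empty index list and every caller short-circuits. A rough Python analogue of the control flow (hypothetical helper names, for illustration only):

```python
import numpy as np

def deal_with_bool_indices(mask: np.ndarray):
    nonzero = np.argwhere(mask)  # analogue of NonZeroKernel
    if nonzero.size == 0:
        return []                # sentinel: all-False mask selects nothing
    # One integer index array per dimension, as the C++ helper produces.
    return [nonzero[:, d] for d in range(mask.ndim)]

def index_put(x, mask, value, accumulate=False):
    idx = deal_with_bool_indices(mask)
    if not idx:                  # the early return added by this commit
        return x.copy()
    out = x.copy()
    if accumulate:
        np.add.at(out, tuple(idx), value)
    else:
        out[tuple(idx)] = value
    return out
```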
@@ -18,6 +18,7 @@
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/cast_kernel.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/index_put_utils.h"
 #include "paddle/phi/kernels/reduce_sum_kernel.h"
@@ -219,6 +220,20 @@ void IndexPutGradKernel(const Context& dev_ctx,
   std::vector<DenseTensor> tmp_args;
   std::vector<const phi::DenseTensor*> int_indices_v =
       funcs::DealWithBoolIndices<T, Context>(dev_ctx, indices, &tmp_args);
+  if (int_indices_v.empty()) {
+    if (x_grad) {
+      phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
+    }
+    if (value_grad) {
+      FullKernel<T, Context>(dev_ctx,
+                             phi::vectorize(value_grad->dims()),
+                             0.0f,
+                             value_grad->dtype(),
+                             value_grad);
+    }
+    return;
+  }
+
   const size_t total_dims = x.dims().size();
   auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v);
...
@@ -118,6 +118,12 @@ void IndexPutKernel(const Context& dev_ctx,
   std::vector<DenseTensor> tmp_args;
   std::vector<const phi::DenseTensor*> int_indices_v =
       funcs::DealWithBoolIndices<T, Context>(dev_ctx, indices, &tmp_args);
+  if (int_indices_v.empty()) {
+    if (!out->initialized()) {
+      phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
+    }
+    return;
+  }
   const size_t total_dims = x.dims().size();
   auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v);
...
@@ -47,14 +47,15 @@ def has_duplicate_index(indices, shapes):
     return True
 
 
-def gen_indices_np(x_shape, indices_shapes, index_type):
+def gen_indices_np(x_shape, indices_shapes, index_type, is_all_false):
     indices = []
     if index_type == np.bool_:
         indice = np.zeros(indices_shapes[0], dtype=np.bool_)
-        indice.flatten()
-        for i in range(len(indice)):
-            indice[i] = (i & 1) == 0
-        indice = indice.reshape(indices_shapes[0])
+        if not is_all_false:
+            indice.flatten()
+            for i in range(len(indice)):
+                indice[i] = (i & 1) == 0
+            indice = indice.reshape(indices_shapes[0])
         indices.append(indice)
     else:
         while True:
@@ -78,6 +79,7 @@ def gen_indices_np(x_shape, indices_shapes, index_type):
 class TestIndexPutAPIBase(unittest.TestCase):
     def setUp(self):
         self.mixed_indices = False
+        self.is_all_false = False
         self.init_dtype_type()
         self.setPlace()
         self.x_np = np.random.random(self.x_shape).astype(self.dtype_np)
@@ -85,17 +87,26 @@ class TestIndexPutAPIBase(unittest.TestCase):
         if self.mixed_indices:
             tmp_indices_np1 = gen_indices_np(
-                self.x_shape, self.indices_shapes, self.index_type_np
+                self.x_shape,
+                self.indices_shapes,
+                self.index_type_np,
+                self.is_all_false,
             )
             tmp_indices_np2 = gen_indices_np(
-                self.x_shape, self.indices_shapes1, self.index_type_np1
+                self.x_shape,
+                self.indices_shapes1,
+                self.index_type_np1,
+                self.is_all_false,
             )
             self.indices_np = tuple(
                 list(tmp_indices_np1) + list(tmp_indices_np2)
             )
         else:
             self.indices_np = gen_indices_np(
-                self.x_shape, self.indices_shapes, self.index_type_np
+                self.x_shape,
+                self.indices_shapes,
+                self.index_type_np,
+                self.is_all_false,
             )
 
     def init_dtype_type(self):
@@ -565,6 +576,32 @@ class TestIndexPutAPI30(TestIndexPutAPIBase):
         self.accumulate = True
 
+
+class TestIndexPutAPI31(TestIndexPutAPIBase):
+    def init_dtype_type(self):
+        self.dtype_np = np.bool_
+        self.index_type_np = np.int32
+        self.x_shape = (100, 110)
+        self.indices_shapes = [(21,), (21,)]
+        self.value_shape = (21,)
+        self.dtype_pd = paddle.bool
+        self.index_type_pd = paddle.int32
+        self.accumulate = False
+        self.is_all_false = True
+
+
+class TestIndexPutAPI32(TestIndexPutAPIBase):
+    def init_dtype_type(self):
+        self.dtype_np = np.bool_
+        self.index_type_np = np.int32
+        self.x_shape = (100, 110)
+        self.indices_shapes = [(21,), (21,)]
+        self.value_shape = (21,)
+        self.dtype_pd = paddle.bool
+        self.index_type_pd = paddle.int32
+        self.accumulate = True
+        self.is_all_false = True
+
+
 class TestIndexPutInplaceAPI(unittest.TestCase):
     def setUp(self):
         self.init_dtype_type()
@@ -572,7 +609,7 @@ class TestIndexPutInplaceAPI(unittest.TestCase):
         self.x_np = np.random.random(self.x_shape).astype(self.dtype_np)
         self.value_np = np.random.random(self.value_shape).astype(self.dtype_np)
         self.indices_np = gen_indices_np(
-            self.x_shape, self.indices_shapes, self.index_type_np
+            self.x_shape, self.indices_shapes, self.index_type_np, False
         )
 
     def init_dtype_type(self):
@@ -678,7 +715,7 @@ class TestIndexPutAPIBackward(unittest.TestCase):
             atol=1e-7,
         )
 
-    def test_backwardScalarVal(self):
+    def test_backward_scalarval(self):
        paddle.disable_static()
         for place in self.place:
             paddle.device.set_device(place)
@@ -719,7 +756,7 @@ class TestIndexPutAPIBackward(unittest.TestCase):
             np.array([4.0], dtype=np.float64), dvalue.numpy(), atol=1e-7
         )
 
-    def test_backwardBroadCastValue(self):
+    def test_backward_broadcastvalue(self):
         paddle.disable_static()
         for place in self.place:
             paddle.device.set_device(place)
@@ -764,7 +801,7 @@ class TestIndexPutAPIBackward(unittest.TestCase):
             atol=1e-7,
         )
 
-    def test_backwardBroadCastValue1(self):
+    def test_backward_broadcastvalue1(self):
         paddle.disable_static()
         for place in self.place:
             paddle.device.set_device(place)
@@ -809,7 +846,7 @@ class TestIndexPutAPIBackward(unittest.TestCase):
             atol=1e-7,
         )
 
-    def test_backwardBroadCastValue2(self):
+    def test_backward_broadcastvalue2(self):
         paddle.disable_static()
         for place in self.place:
             paddle.device.set_device(place)
@@ -854,6 +891,50 @@ class TestIndexPutAPIBackward(unittest.TestCase):
             atol=1e-7,
         )
 
+    def test_backward_all_false_bool_indice(self):
+        paddle.disable_static()
+        for place in self.place:
+            paddle.device.set_device(place)
+            value = paddle.ones(shape=[2, 1], dtype=paddle.float64)
+            x = paddle.ones(shape=[16, 21], dtype=paddle.float64)
+            ix = paddle.zeros(shape=[16, 21], dtype=paddle.bool)
+
+            value.stop_gradient = False
+            x.stop_gradient = False
+            out = paddle.index_put(x, (ix,), value, False)
+            dx, dvalue = paddle.grad(
+                outputs=[out],
+                inputs=[x, value],
+                create_graph=False,
+                retain_graph=True,
+            )
+            ref_dx = np.ones(shape=[16, 21], dtype=np.float64)
+
+            np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7)
+            np.testing.assert_allclose(
+                np.array([[0.0], [0.0]], dtype=np.float64),
+                dvalue.numpy(),
+                atol=1e-7,
+            )
+
+            out = paddle.index_put(x, (ix,), value, True)
+            dx, dvalue = paddle.grad(
+                outputs=[out],
+                inputs=[x, value],
+                create_graph=False,
+                retain_graph=True,
+            )
+            ref_dx = np.ones(shape=[16, 21], dtype=np.float64)
+
+            np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7)
+            np.testing.assert_allclose(
+                np.array([[0.0], [0.0]], dtype=np.float64),
+                dvalue.numpy(),
+                atol=1e-7,
+            )
+
     def test_backward_in_static(self):
         paddle.enable_static()
         exe = paddle.static.Executor()
...