Unverified commit c737f0ae authored by 傅剑寒, committed by GitHub

add all false bool indices support for index_put (#55655)

Parent 7da1ffbe
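Note: this change makes paddle.index_put a no-op when every entry of a boolean index is false, instead of failing on the empty index set. A minimal sketch of the resulting behavior (the call signature follows the tests further down in this diff):

    import paddle

    x = paddle.ones([3, 4])
    mask = paddle.zeros([3, 4], dtype='bool')  # selects no elements
    value = paddle.to_tensor([9.0])

    # with nothing selected, index_put returns x unchanged
    out = paddle.index_put(x, (mask,), value, False)
    assert paddle.allclose(out, x)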
@@ -16,6 +16,7 @@
 #include <numeric>
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/cast_kernel.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/index_put_utils.h"
 #include "paddle/phi/kernels/reduce_sum_kernel.h"
@@ -188,6 +189,19 @@ void IndexPutGradKernel(const Context& dev_ctx,
   std::vector<DenseTensor> tmp_args;
   std::vector<const phi::DenseTensor*> int_indices_v =
       funcs::DealWithBoolIndices<T, Context>(dev_ctx, indices, &tmp_args);
+  if (int_indices_v.empty()) {
+    if (x_grad) {
+      phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
+    }
+    if (value_grad) {
+      FullKernel<T, Context>(dev_ctx,
+                             phi::vectorize(value_grad->dims()),
+                             0.0f,
+                             value_grad->dtype(),
+                             value_grad);
+    }
+    return;
+  }
   auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v);
   std::vector<int64_t> res_dim_v(phi::vectorize(bd_dim));
......
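When no element is selected, index_put degenerates to the identity on x, so the backward pass above copies out_grad straight through into x_grad and fills value_grad with zeros (value never reaches the output). A NumPy sketch of the same rule, with hypothetical shapes:

    import numpy as np

    out_grad = np.random.rand(4, 5)                      # incoming gradient
    x_grad = out_grad.copy()                             # identity forward => pass-through
    value_grad = np.zeros((3,), dtype=out_grad.dtype)    # value unused => zero gradient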
@@ -117,6 +117,12 @@ void IndexPutKernel(const Context& dev_ctx,
   std::vector<DenseTensor> tmp_args;
   std::vector<const phi::DenseTensor*> int_indices_v =
       funcs::DealWithBoolIndices<T, Context>(dev_ctx, indices, &tmp_args);
+  if (int_indices_v.empty()) {
+    if (!out->initialized()) {
+      phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
+    }
+    return;
+  }
   auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v);
......
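The forward early return matches NumPy's reference semantics: assigning through an all-false mask touches zero elements. The `!out->initialized()` guard skips the copy when out already holds the data, presumably the inplace path where out aliases x. For comparison:

    import numpy as np

    x = np.arange(6.0).reshape(2, 3)
    mask = np.zeros((2, 3), dtype=bool)
    x[mask] = 9.0    # writes to zero positions; x is unchanged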
@@ -88,6 +88,11 @@ std::vector<const phi::DenseTensor*> DealWithBoolIndices(
       nonzero_indices.Resize(phi::make_ddim({-1, rank}));
       NonZeroKernel<bool, Context>(dev_ctx, *indices_v[i], &nonzero_indices);
+      if (nonzero_indices.numel() == 0) {
+        std::vector<const phi::DenseTensor*> empty_indices;
+        return empty_indices;
+      }
       std::vector<phi::DenseTensor*> integer_indices(rank, nullptr);
       const int tmp_ix = tmp_indices_v->size();
       for (int i = 0; i < rank; ++i) {
......
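DealWithBoolIndices lowers a boolean mask to integer coordinate tensors via NonZeroKernel, the analogue of np.nonzero. An all-false mask produces zero coordinate rows, which previously had no dedicated handling; returning an empty vector lets the kernels above detect the case and bail out before the broadcast/scatter machinery runs. The NumPy analogue:

    import numpy as np

    mask = np.zeros((2, 3), dtype=bool)
    coords = np.nonzero(mask)          # analogue of NonZeroKernel
    print([c.size for c in coords])    # [0, 0]: nothing to index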
@@ -18,6 +18,7 @@
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/cast_kernel.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/index_put_utils.h"
 #include "paddle/phi/kernels/reduce_sum_kernel.h"
@@ -219,6 +220,20 @@ void IndexPutGradKernel(const Context& dev_ctx,
   std::vector<DenseTensor> tmp_args;
   std::vector<const phi::DenseTensor*> int_indices_v =
       funcs::DealWithBoolIndices<T, Context>(dev_ctx, indices, &tmp_args);
+  if (int_indices_v.empty()) {
+    if (x_grad) {
+      phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
+    }
+    if (value_grad) {
+      FullKernel<T, Context>(dev_ctx,
+                             phi::vectorize(value_grad->dims()),
+                             0.0f,
+                             value_grad->dtype(),
+                             value_grad);
+    }
+    return;
+  }
   const size_t total_dims = x.dims().size();
   auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v);
......
@@ -118,6 +118,12 @@ void IndexPutKernel(const Context& dev_ctx,
   std::vector<DenseTensor> tmp_args;
   std::vector<const phi::DenseTensor*> int_indices_v =
       funcs::DealWithBoolIndices<T, Context>(dev_ctx, indices, &tmp_args);
+  if (int_indices_v.empty()) {
+    if (!out->initialized()) {
+      phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
+    }
+    return;
+  }
   const size_t total_dims = x.dims().size();
   auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v);
......
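The two GPU hunks add the identical early returns as the CPU path, so the no-op guarantee holds for both accumulate modes on either backend. A quick check one could run (a sketch, not part of the test suite):

    import paddle

    x = paddle.rand([4, 4])
    mask = paddle.zeros([4, 4], dtype='bool')
    v = paddle.to_tensor([1.0])
    # assign mode and accumulate mode are both identity on an all-false mask
    assert paddle.allclose(paddle.index_put(x, (mask,), v, accumulate=False), x)
    assert paddle.allclose(paddle.index_put(x, (mask,), v, accumulate=True), x)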
@@ -47,14 +47,15 @@ def has_duplicate_index(indices, shapes):
     return True


-def gen_indices_np(x_shape, indices_shapes, index_type):
+def gen_indices_np(x_shape, indices_shapes, index_type, is_all_false):
     indices = []
     if index_type == np.bool_:
         indice = np.zeros(indices_shapes[0], dtype=np.bool_)
-        indice.flatten()
-        for i in range(len(indice)):
-            indice[i] = (i & 1) == 0
-        indice = indice.reshape(indices_shapes[0])
+        if not is_all_false:
+            indice.flatten()
+            for i in range(len(indice)):
+                indice[i] = (i & 1) == 0
+            indice = indice.reshape(indices_shapes[0])
         indices.append(indice)
     else:
         while True:
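With is_all_false set, the helper now returns the np.zeros mask untouched; otherwise it keeps the old pattern of marking every even flattened position. (As written, `indice.flatten()` returns a copy and has no effect, so the loop assumes the shapes the tests actually pass.) A vectorized equivalent of the populated branch for a 1-D shape would be:

    import numpy as np

    indice = np.zeros((21,), dtype=np.bool_)
    indice[::2] = True    # same as setting positions where (i & 1) == 0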
@@ -78,6 +79,7 @@ def gen_indices_np(x_shape, indices_shapes, index_type):
 class TestIndexPutAPIBase(unittest.TestCase):
     def setUp(self):
         self.mixed_indices = False
+        self.is_all_false = False
         self.init_dtype_type()
         self.setPlace()
         self.x_np = np.random.random(self.x_shape).astype(self.dtype_np)
@@ -85,17 +87,26 @@ class TestIndexPutAPIBase(unittest.TestCase):
         if self.mixed_indices:
             tmp_indices_np1 = gen_indices_np(
-                self.x_shape, self.indices_shapes, self.index_type_np
+                self.x_shape,
+                self.indices_shapes,
+                self.index_type_np,
+                self.is_all_false,
             )
             tmp_indices_np2 = gen_indices_np(
-                self.x_shape, self.indices_shapes1, self.index_type_np1
+                self.x_shape,
+                self.indices_shapes1,
+                self.index_type_np1,
+                self.is_all_false,
             )
             self.indices_np = tuple(
                 list(tmp_indices_np1) + list(tmp_indices_np2)
             )
         else:
             self.indices_np = gen_indices_np(
-                self.x_shape, self.indices_shapes, self.index_type_np
+                self.x_shape,
+                self.indices_shapes,
+                self.index_type_np,
+                self.is_all_false,
             )

     def init_dtype_type(self):
@@ -565,6 +576,32 @@ class TestIndexPutAPI30(TestIndexPutAPIBase):
         self.accumulate = True


+class TestIndexPutAPI31(TestIndexPutAPIBase):
+    def init_dtype_type(self):
+        self.dtype_np = np.bool_
+        self.index_type_np = np.int32
+        self.x_shape = (100, 110)
+        self.indices_shapes = [(21,), (21,)]
+        self.value_shape = (21,)
+        self.dtype_pd = paddle.bool
+        self.index_type_pd = paddle.int32
+        self.accumulate = False
+        self.is_all_false = True
+
+
+class TestIndexPutAPI32(TestIndexPutAPIBase):
+    def init_dtype_type(self):
+        self.dtype_np = np.bool_
+        self.index_type_np = np.int32
+        self.x_shape = (100, 110)
+        self.indices_shapes = [(21,), (21,)]
+        self.value_shape = (21,)
+        self.dtype_pd = paddle.bool
+        self.index_type_pd = paddle.int32
+        self.accumulate = True
+        self.is_all_false = True
+
+
 class TestIndexPutInplaceAPI(unittest.TestCase):
     def setUp(self):
         self.init_dtype_type()
@@ -572,7 +609,7 @@ class TestIndexPutInplaceAPI(unittest.TestCase):
         self.x_np = np.random.random(self.x_shape).astype(self.dtype_np)
         self.value_np = np.random.random(self.value_shape).astype(self.dtype_np)
         self.indices_np = gen_indices_np(
-            self.x_shape, self.indices_shapes, self.index_type_np
+            self.x_shape, self.indices_shapes, self.index_type_np, False
         )

     def init_dtype_type(self):
@@ -678,7 +715,7 @@ class TestIndexPutAPIBackward(unittest.TestCase):
                 atol=1e-7,
             )

-    def test_backwardScalarVal(self):
+    def test_backward_scalarval(self):
         paddle.disable_static()
         for place in self.place:
             paddle.device.set_device(place)
@@ -719,7 +756,7 @@ class TestIndexPutAPIBackward(unittest.TestCase):
                 np.array([4.0], dtype=np.float64), dvalue.numpy(), atol=1e-7
             )

-    def test_backwardBroadCastValue(self):
+    def test_backward_broadcastvalue(self):
         paddle.disable_static()
         for place in self.place:
             paddle.device.set_device(place)
@@ -764,7 +801,7 @@ class TestIndexPutAPIBackward(unittest.TestCase):
                 atol=1e-7,
             )

-    def test_backwardBroadCastValue1(self):
+    def test_backward_broadcastvalue1(self):
         paddle.disable_static()
         for place in self.place:
             paddle.device.set_device(place)
@@ -809,7 +846,7 @@ class TestIndexPutAPIBackward(unittest.TestCase):
                 atol=1e-7,
             )

-    def test_backwardBroadCastValue2(self):
+    def test_backward_broadcastvalue2(self):
         paddle.disable_static()
         for place in self.place:
             paddle.device.set_device(place)
@@ -854,6 +891,50 @@ class TestIndexPutAPIBackward(unittest.TestCase):
                 atol=1e-7,
             )

+    def test_backward_all_false_bool_indice(self):
+        paddle.disable_static()
+        for place in self.place:
+            paddle.device.set_device(place)
+            value = paddle.ones(shape=[2, 1], dtype=paddle.float64)
+            x = paddle.ones(shape=[16, 21], dtype=paddle.float64)
+            ix = paddle.zeros(shape=[16, 21], dtype=paddle.bool)
+            value.stop_gradient = False
+            x.stop_gradient = False
+
+            out = paddle.index_put(x, (ix,), value, False)
+            dx, dvalue = paddle.grad(
+                outputs=[out],
+                inputs=[x, value],
+                create_graph=False,
+                retain_graph=True,
+            )
+            ref_dx = np.ones(shape=[16, 21], dtype=np.float64)
+
+            np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7)
+            np.testing.assert_allclose(
+                np.array([[0.0], [0.0]], dtype=np.float64),
+                dvalue.numpy(),
+                atol=1e-7,
+            )
+
+            out = paddle.index_put(x, (ix,), value, True)
+            dx, dvalue = paddle.grad(
+                outputs=[out],
+                inputs=[x, value],
+                create_graph=False,
+                retain_graph=True,
+            )
+            ref_dx = np.ones(shape=[16, 21], dtype=np.float64)
+
+            np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7)
+            np.testing.assert_allclose(
+                np.array([[0.0], [0.0]], dtype=np.float64),
+                dvalue.numpy(),
+                atol=1e-7,
+            )
+
     def test_backward_in_static(self):
         paddle.enable_static()
         exe = paddle.static.Executor()
......
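In the new backward test, paddle.grad is called without grad_outputs, which defaults the cotangent of out to all ones; since the forward is the identity on x, ref_dx is all ones, and since value never enters the output, dvalue is all zeros in both the assign and accumulate runs. The expected references, spelled out:

    import numpy as np

    ref_dx = np.ones((16, 21), dtype=np.float64)       # identity pass-through of the ones cotangent
    ref_dvalue = np.zeros((2, 1), dtype=np.float64)    # value unused in the forward pass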