From 413d1abf30be3ec17bafd94d13aa0a9e1dd9445e Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Wed, 5 Jul 2023 15:38:09 +0800
Subject: [PATCH] masked select x and mask support broadcast (#54776)

* masked select forward support broadcast

* cpu forward and backward

* gpu support mask broadcast

* fix comment

* x support broadcast

* fix comment
---
 .../kernels/cpu/masked_select_grad_kernel.cc | 60 +++++++++++++-
 .../phi/kernels/cpu/masked_select_kernel.cc  | 33 ++++++--
 .../kernels/gpu/masked_select_grad_kernel.cu | 48 ++++++++++-
 .../phi/kernels/gpu/masked_select_kernel.cu  | 32 ++++++--
 test/legacy_test/test_masked_select_op.py    | 80 ++++++++++++++++++-
 5 files changed, 234 insertions(+), 19 deletions(-)

diff --git a/paddle/phi/kernels/cpu/masked_select_grad_kernel.cc b/paddle/phi/kernels/cpu/masked_select_grad_kernel.cc
index 8b1ca87b4a2..09e4b80c859 100644
--- a/paddle/phi/kernels/cpu/masked_select_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/masked_select_grad_kernel.cc
@@ -14,22 +14,57 @@
 
 #include "paddle/phi/kernels/masked_select_grad_kernel.h"
 
+#include "paddle/phi/kernels/expand_kernel.h"
+
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/empty_kernel.h"
+#include "paddle/phi/kernels/expand_grad_kernel.h"
+#include "paddle/phi/kernels/funcs/common_shape.h"
 
 namespace phi {
 
 template <typename T, typename Context>
 void MaskedSelectGradKernel(const Context& dev_ctx,
-                            const DenseTensor& x UNUSED,
+                            const DenseTensor& x,
                             const DenseTensor& mask,
                             const DenseTensor& out_grad,
                             DenseTensor* x_grad) {
-  auto* mask_data = mask.data<bool>();
-  auto* input_data = out_grad.data<T>();
+  // x_grad.size() == x.size()
+  // x.size() == mask.size(): no broadcast (expand_mask = false, expand_x = false)
+  // x.size() < mask.size(): x broadcasts to mask (expand_mask = false, expand_x = true)
+  // x.size() > mask.size(): mask broadcasts to x (expand_mask = true, expand_x = false)
+  DenseTensor mask_expand;
+  DenseTensor x_grad_expand;
+  bool expand_x = false;
+
+  auto expanded_size = funcs::MatrixGetBroadcastBatchPortion(
+      vectorize(x_grad->dims()), vectorize(mask.dims()));
+  auto expanded_dims = make_ddim(expanded_size);
+
+  if (mask.dims() != expanded_dims) {
+    ExpandKernel<bool, Context>(
+        dev_ctx, mask, IntArray(expanded_size), &mask_expand);
+  } else {
+    mask_expand = mask;
+  }
+
+  if (x_grad->dims() != expanded_dims) {
+    x_grad_expand = Empty<T>(dev_ctx, IntArray(expanded_size));
+    expand_x = true;
+  } else {
+    expand_x = false;
+  }
 
   auto* out_data = dev_ctx.template Alloc<T>(x_grad);
-  int mask_size = mask.numel();
+  if (expand_x) {
+    out_data = x_grad_expand.data<T>();
+  }
+
+  auto* mask_data = mask_expand.data<bool>();
+  auto* input_data = out_grad.data<T>();
+  int mask_size = mask_expand.numel();
 
   int index = 0;
   for (int i = 0; i < mask_size; i++) {
@@ -40,6 +75,23 @@ void MaskedSelectGradKernel(const Context& dev_ctx,
       out_data[i] = 0;
     }
   }
+
+  auto out_grad_numel = out_grad.numel();
+  PADDLE_ENFORCE_EQ(
+      index,
+      out_grad_numel,
+      phi::errors::InvalidArgument(
+          "The number of true elements in mask must equal the numel of "
+          "out_grad in OP(masked_select_grad), but got true count: (%ld), "
+          "out_grad numel: (%ld). Please check input value.",
+          index,
+          out_grad_numel));
+
+  if (expand_x) {
+    ExpandGradKernel<T, Context>(
+        dev_ctx, x, x_grad_expand, IntArray(expanded_size), x_grad);
+  }
 }
 
 }  // namespace phi
diff --git a/paddle/phi/kernels/cpu/masked_select_kernel.cc b/paddle/phi/kernels/cpu/masked_select_kernel.cc
index 33311c26cfe..837a8921e81 100644
--- a/paddle/phi/kernels/cpu/masked_select_kernel.cc
+++ b/paddle/phi/kernels/cpu/masked_select_kernel.cc
@@ -13,9 +13,10 @@
 // limitations under the License.
 
 #include "paddle/phi/kernels/masked_select_kernel.h"
-
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/expand_kernel.h"
+#include "paddle/phi/kernels/funcs/common_shape.h"
 
 namespace phi {
 
@@ -24,13 +25,28 @@ void MaskedSelectKernel(const Context& dev_ctx,
                         const DenseTensor& x,
                         const DenseTensor& mask,
                         DenseTensor* out) {
-  auto* mask_data = mask.data<bool>();
-  auto input_data = x.data<T>();
+  DenseTensor mask_expand;
+  DenseTensor x_expand;
+
+  auto expanded_size = funcs::MatrixGetBroadcastBatchPortion(
+      vectorize(x.dims()), vectorize(mask.dims()));
 
-  auto mask_size = mask.numel();
+  DDim expanded_dims = make_ddim(expanded_size);
+  if (mask.dims() != expanded_dims) {
+    ExpandKernel<bool, Context>(
+        dev_ctx, mask, IntArray(expanded_size), &mask_expand);
+  } else {
+    mask_expand = mask;
+  }
 
-  auto input_dim = x.dims();
-  auto mask_dim = mask.dims();
+  if (x.dims() != expanded_dims) {
+    ExpandKernel<T, Context>(dev_ctx, x, IntArray(expanded_size), &x_expand);
+  } else {
+    x_expand = x;
+  }
+
+  auto input_dim = x_expand.dims();
+  auto mask_dim = mask_expand.dims();
   PADDLE_ENFORCE_EQ(input_dim,
                     mask_dim,
                     phi::errors::InvalidArgument(
@@ -41,6 +57,11 @@ void MaskedSelectKernel(const Context& dev_ctx,
                         input_dim,
                         mask_dim));
 
+  auto input_data = x_expand.data<T>();
+  auto mask_data = mask_expand.data<bool>();
+
+  auto mask_size = mask_expand.numel();
+
   int out_size = 0;
   for (int i = 0; i < mask_size; i++) {
     if (mask_data[i]) out_size++;
diff --git a/paddle/phi/kernels/gpu/masked_select_grad_kernel.cu b/paddle/phi/kernels/gpu/masked_select_grad_kernel.cu
index 1121ff361f8..983fdb26564 100644
--- a/paddle/phi/kernels/gpu/masked_select_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/masked_select_grad_kernel.cu
@@ -21,6 +21,11 @@
 
 #include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/empty_kernel.h"
+#include "paddle/phi/kernels/expand_grad_kernel.h"
+#include "paddle/phi/kernels/expand_kernel.h"
+#include "paddle/phi/kernels/funcs/common_shape.h"
+#include "paddle/phi/kernels/funcs/reduce_function.h"
 #include "paddle/phi/kernels/funcs/select_impl.cu.h"
 
 namespace phi {
@@ -50,12 +55,51 @@ void MaskedSelectGradKernel(const Context& dev_ctx,
                             const DenseTensor& x,
                             const DenseTensor& mask,
                             const DenseTensor& out_grad,
                             DenseTensor* x_grad) {
-  auto mask_size = mask.numel();
+  // x_grad.size() == x.size()
+  // x.size() == mask.size(): no broadcast (expand_mask = false, expand_x = false)
+  // x.size() < mask.size(): x broadcasts to mask (expand_mask = false, expand_x = true)
+  // x.size() > mask.size(): mask broadcasts to x (expand_mask = true, expand_x = false)
+  DenseTensor mask_expand;
+  DenseTensor x_grad_expand;
+  bool expand_x = false;
+
+  auto expanded_size = funcs::MatrixGetBroadcastBatchPortion(
+      vectorize(x_grad->dims()), vectorize(mask.dims()));
+  auto expanded_dims = make_ddim(expanded_size);
+
+  if (mask.dims() != expanded_dims) {
+    ExpandKernel<bool, Context>(
+        dev_ctx, mask, IntArray(expanded_size), &mask_expand);
+  } else {
+    mask_expand = mask;
+  }
+
+  if (x_grad->dims() != expanded_dims) {
+    x_grad_expand = Empty<T>(dev_ctx, IntArray(expanded_size));
+    expand_x = true;
+  } else {
+    expand_x = false;
+  }
+
   dev_ctx.template Alloc<T>(x_grad);
 
+  auto mask_size = mask_expand.numel();
   if (mask_size <= 0) return;
+
   using Functor = MaskedSelectGradFunctor<bool, T, T>;
+
+  DenseTensor* x_grad_tmp = x_grad;
+  if (expand_x) {
+    x_grad_tmp = &x_grad_expand;
+  }
+
   phi::funcs::SelectKernel<bool, T, T, 2, Functor>(
-      dev_ctx, mask, out_grad, x_grad, Functor());
+      dev_ctx, mask_expand, out_grad, x_grad_tmp, Functor());
+
+  if (expand_x) {
+    ExpandGradKernel<T, Context>(
+        dev_ctx, x, x_grad_expand, IntArray(expanded_size), x_grad);
+  }
 }
 
 }  // namespace phi
diff --git a/paddle/phi/kernels/gpu/masked_select_kernel.cu b/paddle/phi/kernels/gpu/masked_select_kernel.cu
index 208bdd853cc..89cb714d78d 100644
--- a/paddle/phi/kernels/gpu/masked_select_kernel.cu
+++ b/paddle/phi/kernels/gpu/masked_select_kernel.cu
@@ -22,6 +22,8 @@
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/expand_kernel.h"
+#include "paddle/phi/kernels/funcs/common_shape.h"
 #include "paddle/phi/kernels/funcs/select_impl.cu.h"
 
 namespace phi {
@@ -48,12 +50,29 @@ void MaskedSelectKernel(const Context& dev_ctx,
                         const DenseTensor& x,
                         const DenseTensor& mask,
                         DenseTensor* out) {
-  auto* mask_data = mask.data<bool>();
-  auto input_data = x.data<T>();
+  DenseTensor mask_expand;
+  DenseTensor x_expand;
 
-  auto mask_size = mask.numel();
-  auto input_dim = x.dims();
-  auto mask_dim = mask.dims();
+  auto expanded_size = funcs::MatrixGetBroadcastBatchPortion(
+      vectorize(x.dims()), vectorize(mask.dims()));
+
+  DDim expanded_dims = make_ddim(expanded_size);
+  if (mask.dims() != expanded_dims) {
+    phi::ExpandKernel<bool, Context>(
+        dev_ctx, mask, IntArray(expanded_size), &mask_expand);
+  } else {
+    mask_expand = mask;
+  }
+
+  if (x.dims() != expanded_dims) {
+    phi::ExpandKernel<T, Context>(
+        dev_ctx, x, IntArray(expanded_size), &x_expand);
+  } else {
+    x_expand = x;
+  }
+
+  auto input_dim = x_expand.dims();
+  auto mask_dim = mask_expand.dims();
   PADDLE_ENFORCE_EQ(input_dim,
                     mask_dim,
                     phi::errors::InvalidArgument(
@@ -63,9 +82,10 @@ void MaskedSelectKernel(const Context& dev_ctx,
                         "value.",
                         input_dim,
                         mask_dim));
+
   using Functor = MaskedSelectFunctor<bool, T, T>;
   phi::funcs::SelectKernel<bool, T, T, 1, Functor>(
-      dev_ctx, mask, x, out, Functor());
+      dev_ctx, mask_expand, x_expand, out, Functor());
 }
 
 }  // namespace phi
diff --git a/test/legacy_test/test_masked_select_op.py b/test/legacy_test/test_masked_select_op.py
index cc402947e6c..6d89ced3ecc 100644
--- a/test/legacy_test/test_masked_select_op.py
+++ b/test/legacy_test/test_masked_select_op.py
@@ -23,6 +23,7 @@ from paddle.fluid import core
 
 def np_masked_select(x, mask):
     result = np.empty(shape=(0), dtype=x.dtype)
+    x, mask = np.broadcast_arrays(x, mask)
     for ele, ma in zip(np.nditer(x), np.nditer(mask)):
         if ma:
             result = np.append(result, ele)
@@ -35,7 +36,7 @@ class TestMaskedSelectOp(OpTest):
         self.op_type = "masked_select"
         self.python_api = paddle.masked_select
         x = np.random.random(self.shape).astype("float64")
-        mask = np.array(np.random.randint(2, size=self.shape, dtype=bool))
+        mask = np.array(np.random.randint(2, size=self.mask_shape, dtype=bool))
         out = np_masked_select(x, mask)
         self.inputs = {'X': x, 'Mask': mask}
         self.outputs = {'Y': out}
@@ -48,16 +49,19 @@ class TestMaskedSelectOp(OpTest):
 
     def init(self):
         self.shape = (50, 3)
+        self.mask_shape = self.shape
 
 
 class TestMaskedSelectOp1(TestMaskedSelectOp):
     def init(self):
         self.shape = (6, 8, 9, 18)
+        self.mask_shape = self.shape
 
 
 class TestMaskedSelectOp2(TestMaskedSelectOp):
     def init(self):
         self.shape = (168,)
+        self.mask_shape = self.shape
 
 
 class TestMaskedSelectFP16Op(OpTest):
@@ -163,6 +167,9 @@ class TestMaskedSelectAPI(unittest.TestCase):
 
 
 class TestMaskedSelectError(unittest.TestCase):
+    def setUp(self):
+        paddle.enable_static()
+
     def test_error(self):
         with paddle.static.program_guard(
             paddle.static.Program(), paddle.static.Program()
@@ -192,6 +199,77 @@ class TestMaskedSelectError(unittest.TestCase):
         self.assertRaises(TypeError, test_mask_dtype)
 
 
+class TestMaskedSelectBroadcast(unittest.TestCase):
+    def setUp(self):
+        paddle.disable_static()
+
+    def test_broadcast(self):
+        shape = (3, 4)
+        np_x = np.random.random(shape).astype('float32')
+        np_mask = np.array([[True], [False], [False]])
+        x = paddle.to_tensor(np_x)
+        mask = paddle.to_tensor(np_mask)
+        out = paddle.masked_select(x, mask)
+        np_out = np_x[0]
+        np.testing.assert_allclose(out.numpy(), np_out, rtol=1e-05)
+
+    def test_broadcast_grad(self):
+        shape = (3, 4)
+        np_x = np.random.random(shape).astype('float32')
+        np_mask = np.array([[True], [False], [False]])
+        x = paddle.to_tensor(np_x, stop_gradient=False)
+        mask = paddle.to_tensor(np_mask)
+        out = paddle.masked_select(x, mask)
+        out.sum().backward()
+        np_out = np.zeros(shape)
+        np_out[0] = 1.0
+        np.testing.assert_allclose(x.grad.numpy(), np_out, rtol=1e-05)
+
+    def test_broadcast_zerodim(self):
+        shape = (3, 4)
+        np_x = np.random.random(shape).astype('float32')
+        x = paddle.to_tensor(np_x)
+        mask = paddle.to_tensor(True)
+        out = paddle.masked_select(x, mask)
+        np_out = np_x.reshape(-1)
+        np.testing.assert_allclose(out.numpy(), np_out, rtol=1e-05)
+
+    def test_broadcast_zerodim_grad(self):
+        shape = (3, 4)
+        np_x = np.random.random(shape).astype('float32')
+        np_mask = np.array(True)
+        x = paddle.to_tensor(np_x, stop_gradient=False)
+        mask = paddle.to_tensor(np_mask)
+        out = paddle.masked_select(x, mask)
+        out.sum().backward()
+        np_out = np.ones(shape)
+        np.testing.assert_allclose(x.grad.numpy(), np_out, rtol=1e-05)
+
+
+class TestMaskedSelectOpBroadcast(TestMaskedSelectOp):
+    def init(self):
+        self.shape = (3, 40)
+        self.mask_shape = (3, 1)
+
+
+class TestMaskedSelectOpBroadcast2(TestMaskedSelectOp):
+    def init(self):
+        self.shape = (300, 1)
+        self.mask_shape = (300, 40)
+
+
+class TestMaskedSelectOpBroadcast3(TestMaskedSelectOp):
+    def init(self):
+        self.shape = (120,)
+        self.mask_shape = (300, 120)
+
+
+class TestMaskedSelectOpBroadcast4(TestMaskedSelectOp):
+    def init(self):
+        self.shape = (300, 40)
+        self.mask_shape = (40,)
+
 
 if __name__ == '__main__':
     paddle.enable_static()
     unittest.main()
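
Reviewer note: a minimal sketch of the forward behavior this patch enables, using only the public `paddle.masked_select` API. The concrete shapes and values below are illustrative assumptions, not taken from the test suite:

```python
import numpy as np
import paddle

# x: (3, 4); mask: (3, 1). The mask broadcasts across the last axis,
# so each True row in mask selects an entire row of x.
x = paddle.to_tensor(np.arange(12, dtype='float32').reshape(3, 4))
mask = paddle.to_tensor(np.array([[True], [False], [True]]))

out = paddle.masked_select(x, mask)  # 1-D tensor of the selected elements
print(out.numpy())  # [ 0.  1.  2.  3.  8.  9. 10. 11.]

# NumPy reference: broadcast both operands, then boolean-index,
# mirroring np_masked_select in the updated test file.
x_np, mask_np = np.broadcast_arrays(x.numpy(), mask.numpy())
assert np.array_equal(out.numpy(), x_np[mask_np])
```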
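A companion sketch of the backward path: when x is broadcast up to mask's shape, the gradient computed against the expanded tensor is summed back over the broadcast axes, which is the reduction ExpandGradKernel performs in the kernels above. Again, the shapes and values are illustrative assumptions:

```python
import numpy as np
import paddle

# x: (1, 4) broadcasts against mask: (3, 1) to the common shape (3, 4).
x = paddle.to_tensor(np.ones((1, 4), dtype='float32'), stop_gradient=False)
mask = paddle.to_tensor(np.array([[True], [True], [False]]))

out = paddle.masked_select(x, mask)  # rows 0 and 1 selected: 8 elements
out.sum().backward()

# Each element of x appears in two selected rows of the expanded tensor,
# so its gradient is 1 + 1 = 2 after the sum-reduction over the broadcast axis.
print(x.grad.numpy())  # [[2. 2. 2. 2.]]
```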