Unverified commit 413d1abf, authored by Hui Zhang, committed by GitHub

masked select x and mask support broadcast (#54776)

* masked select forward support broadcast

* cpu forward and backward

* gpu support mask broadcast

* fix comment

* x support broadcast

* fix comment
Parent 567dabeb
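For context, a minimal usage sketch of the behavior this commit enables (mirroring the new tests at the bottom of the diff): `x` and `mask` no longer need identical shapes; they are broadcast to a common shape before selection.

```python
import paddle

x = paddle.rand([3, 4])
mask = paddle.to_tensor([[True], [False], [True]])  # shape (3, 1) broadcasts to (3, 4)
out = paddle.masked_select(x, mask)  # rows 0 and 2, flattened -> shape [8]
```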
@@ -14,22 +14,57 @@
#include "paddle/phi/kernels/masked_select_grad_kernel.h"

#include "paddle/phi/kernels/expand_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/expand_grad_kernel.h"
#include "paddle/phi/kernels/funcs/common_shape.h"

namespace phi {

template <typename T, typename Context>
void MaskedSelectGradKernel(const Context& dev_ctx,
                            const DenseTensor& x,
                            const DenseTensor& mask,
                            const DenseTensor& out_grad,
                            DenseTensor* x_grad) {
  // x_grad.size() == x.size()
  // x.size() == mask.size(): no broadcast, expand_mask = false, expand_x = false
  // x.size() <  mask.size(): x broadcasts to mask, expand_mask = false, expand_x = true
  // x.size() >  mask.size(): mask broadcasts to x, expand_mask = true, expand_x = false
  DenseTensor mask_expand;
  DenseTensor x_grad_expand;
  bool expand_x = false;

  auto expanded_size = funcs::MatrixGetBroadcastBatchPortion(
      vectorize(x_grad->dims()), vectorize(mask.dims()));
  auto expanded_dims = make_ddim(expanded_size);

  if (mask.dims() != expanded_dims) {
    ExpandKernel<bool, Context>(
        dev_ctx, mask, IntArray(expanded_size), &mask_expand);
  } else {
    mask_expand = mask;
  }

  if (x_grad->dims() != expanded_dims) {
    x_grad_expand = Empty<T, Context>(dev_ctx, IntArray(expanded_size));
    expand_x = true;
  } else {
    expand_x = false;
  }

  auto* out_data = dev_ctx.template Alloc<T>(x_grad);
  if (expand_x) {
    out_data = x_grad_expand.data<T>();
  }

  auto* mask_data = mask_expand.data<bool>();
  auto* input_data = out_grad.data<T>();
  int mask_size = mask_expand.numel();

  int index = 0;
  for (int i = 0; i < mask_size; i++) {
@@ -40,6 +75,23 @@ void MaskedSelectGradKernel(const Context& dev_ctx,
      out_data[i] = 0;
    }
  }

  auto out_grad_numel = out_grad.numel();
  PADDLE_ENFORCE_EQ(
      index,
      out_grad_numel,
      phi::errors::InvalidArgument(
          "The number of true elements in mask and the numel of out_grad "
          "in OP(masked_select_grad) must be equal, but got true element "
          "count: (%ld), out_grad numel: (%ld). Please check input "
          "value.",
          index,
          out_grad_numel));

  if (expand_x) {
    ExpandGradKernel<T, Context>(
        dev_ctx, x, x_grad_expand, IntArray(expanded_size), x_grad);
  }
}

}  // namespace phi
......
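For intuition, here is a hedged NumPy sketch of the gradient path the CPU kernel above implements: expand the mask to the broadcast shape, scatter `out_grad` into the mask-true positions, and, when `x` itself was broadcast, sum-reduce the expanded gradient back to `x`'s shape (the role `ExpandGradKernel` plays in the `expand_x` case). The function name is illustrative only, not part of the codebase.

```python
import numpy as np

def masked_select_grad_ref(x_shape, mask, out_grad):
    """Illustrative NumPy reference for the broadcast-aware backward pass."""
    bshape = np.broadcast_shapes(x_shape, mask.shape)
    mask_e = np.broadcast_to(mask, bshape)
    # Scatter out_grad at mask-true positions; elsewhere the gradient is 0.
    grad_e = np.zeros(bshape, dtype=out_grad.dtype)
    grad_e[mask_e] = out_grad  # requires out_grad.size == mask_e.sum()
    # Sum-reduce over the axes that were added or expanded relative to x_shape.
    ndiff = len(bshape) - len(x_shape)
    grad = grad_e.sum(axis=tuple(range(ndiff))) if ndiff else grad_e
    for ax, (gs, xs) in enumerate(zip(grad.shape, x_shape)):
        if xs == 1 and gs != 1:
            grad = grad.sum(axis=ax, keepdims=True)
    return grad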
@@ -13,9 +13,10 @@
// limitations under the License.

#include "paddle/phi/kernels/masked_select_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/expand_kernel.h"
#include "paddle/phi/kernels/funcs/common_shape.h"

namespace phi {

@@ -24,13 +25,28 @@ void MaskedSelectKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& mask,
                        DenseTensor* out) {
  DenseTensor mask_expand;
  DenseTensor x_expand;

  auto expanded_size = funcs::MatrixGetBroadcastBatchPortion(
      vectorize(x.dims()), vectorize(mask.dims()));
  DDim expanded_dims = make_ddim(expanded_size);

  if (mask.dims() != expanded_dims) {
    ExpandKernel<bool, Context>(
        dev_ctx, mask, IntArray(expanded_size), &mask_expand);
  } else {
    mask_expand = mask;
  }

  if (x.dims() != expanded_dims) {
    ExpandKernel<T, Context>(dev_ctx, x, IntArray(expanded_size), &x_expand);
  } else {
    x_expand = x;
  }

  auto input_dim = x_expand.dims();
  auto mask_dim = mask_expand.dims();
  PADDLE_ENFORCE_EQ(input_dim,
                    mask_dim,
                    phi::errors::InvalidArgument(
@@ -41,6 +57,11 @@ void MaskedSelectKernel(const Context& dev_ctx,
                        input_dim,
                        mask_dim));

  auto input_data = x_expand.data<T>();
  auto mask_data = mask_expand.data<bool>();
  auto mask_size = mask_expand.numel();

  int out_size = 0;
  for (int i = 0; i < mask_size; i++) {
    if (mask_data[i]) out_size++;
......
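Both forward kernels resolve the common shape via `funcs::MatrixGetBroadcastBatchPortion`. A plausible Python reading of that resolution, assuming standard right-aligned NumPy broadcast rules (the helper name below is hypothetical; the exact C++ error handling may differ):

```python
def broadcast_shape(s1, s2):
    """Right-aligned broadcast of two shapes, NumPy-style."""
    r1, r2 = list(reversed(s1)), list(reversed(s2))
    n = max(len(r1), len(r2))
    r1 += [1] * (n - len(r1))  # pad the shorter shape with 1s
    r2 += [1] * (n - len(r2))
    out = []
    for a, b in zip(r1, r2):
        if a != b and a != 1 and b != 1:
            raise ValueError(f"broadcast mismatch: {a} vs {b}")
        out.append(max(a, b))
    return tuple(reversed(out))

# Shapes taken from the new tests:
assert broadcast_shape((3, 40), (3, 1)) == (3, 40)
assert broadcast_shape((120,), (300, 120)) == (300, 120)
```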
@@ -21,6 +21,11 @@
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/expand_grad_kernel.h"
#include "paddle/phi/kernels/expand_kernel.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/funcs/select_impl.cu.h"

namespace phi {

@@ -50,12 +55,51 @@ void MaskedSelectGradKernel(const Context& dev_ctx,
                            const DenseTensor& mask,
                            const DenseTensor& out_grad,
                            DenseTensor* x_grad) {
  // x_grad.size() == x.size()
  // x.size() == mask.size(): no broadcast, expand_mask = false, expand_x = false
  // x.size() <  mask.size(): x broadcasts to mask, expand_mask = false, expand_x = true
  // x.size() >  mask.size(): mask broadcasts to x, expand_mask = true, expand_x = false
  DenseTensor mask_expand;
  DenseTensor x_grad_expand;
  bool expand_x = false;

  auto expanded_size = funcs::MatrixGetBroadcastBatchPortion(
      vectorize(x_grad->dims()), vectorize(mask.dims()));
  auto expanded_dims = make_ddim(expanded_size);

  if (mask.dims() != expanded_dims) {
    ExpandKernel<bool, Context>(
        dev_ctx, mask, IntArray(expanded_size), &mask_expand);
  } else {
    mask_expand = mask;
  }

  if (x_grad->dims() != expanded_dims) {
    x_grad_expand = Empty<T, Context>(dev_ctx, IntArray(expanded_size));
    expand_x = true;
  } else {
    expand_x = false;
  }

  dev_ctx.template Alloc<T>(x_grad);

  auto mask_size = mask_expand.numel();
  if (mask_size <= 0) return;

  using Functor = MaskedSelectGradFunctor<bool, T, T>;

  DenseTensor* x_grad_tmp = x_grad;
  if (expand_x) {
    x_grad_tmp = &x_grad_expand;
  }

  phi::funcs::SelectKernel<bool, T, T, 2, Functor>(
      dev_ctx, mask_expand, out_grad, x_grad_tmp, Functor());

  if (expand_x) {
    ExpandGradKernel<T, Context>(
        dev_ctx, x, x_grad_expand, IntArray(expanded_size), x_grad);
  }
}

}  // namespace phi
......
@@ -22,6 +22,8 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/expand_kernel.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/select_impl.cu.h"

namespace phi {

@@ -48,12 +50,29 @@ void MaskedSelectKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& mask,
                        DenseTensor* out) {
  DenseTensor mask_expand;
  DenseTensor x_expand;

  auto expanded_size = funcs::MatrixGetBroadcastBatchPortion(
      vectorize(x.dims()), vectorize(mask.dims()));
  DDim expanded_dims = make_ddim(expanded_size);

  if (mask.dims() != expanded_dims) {
    phi::ExpandKernel<bool, Context>(
        dev_ctx, mask, IntArray(expanded_size), &mask_expand);
  } else {
    mask_expand = mask;
  }

  if (x.dims() != expanded_dims) {
    phi::ExpandKernel<T, Context>(
        dev_ctx, x, IntArray(expanded_size), &x_expand);
  } else {
    x_expand = x;
  }

  auto input_dim = x_expand.dims();
  auto mask_dim = mask_expand.dims();
  PADDLE_ENFORCE_EQ(input_dim,
                    mask_dim,
                    phi::errors::InvalidArgument(
@@ -63,9 +82,10 @@ void MaskedSelectKernel(const Context& dev_ctx,
                        "value.",
                        input_dim,
                        mask_dim));

  using Functor = MaskedSelectFunctor<bool, T, T>;

  phi::funcs::SelectKernel<bool, T, T, 1, Functor>(
      dev_ctx, mask_expand, x_expand, out, Functor());
}

}  // namespace phi
......
@@ -23,6 +23,7 @@ from paddle.fluid import core
def np_masked_select(x, mask):
    result = np.empty(shape=(0), dtype=x.dtype)
    x, mask = np.broadcast_arrays(x, mask)
    for ele, ma in zip(np.nditer(x), np.nditer(mask)):
        if ma:
            result = np.append(result, ele)
@@ -35,7 +36,7 @@ class TestMaskedSelectOp(OpTest):
        self.op_type = "masked_select"
        self.python_api = paddle.masked_select
        x = np.random.random(self.shape).astype("float64")
        mask = np.array(np.random.randint(2, size=self.mask_shape, dtype=bool))
        out = np_masked_select(x, mask)
        self.inputs = {'X': x, 'Mask': mask}
        self.outputs = {'Y': out}
@@ -48,16 +49,19 @@ class TestMaskedSelectOp(OpTest):
    def init(self):
        self.shape = (50, 3)
        self.mask_shape = self.shape


class TestMaskedSelectOp1(TestMaskedSelectOp):
    def init(self):
        self.shape = (6, 8, 9, 18)
        self.mask_shape = self.shape


class TestMaskedSelectOp2(TestMaskedSelectOp):
    def init(self):
        self.shape = (168,)
        self.mask_shape = self.shape


class TestMaskedSelectFP16Op(OpTest):
@@ -163,6 +167,9 @@ class TestMaskedSelectAPI(unittest.TestCase):
class TestMaskedSelectError(unittest.TestCase):
    def setUp(self):
        paddle.enable_static()

    def test_error(self):
        with paddle.static.program_guard(
            paddle.static.Program(), paddle.static.Program()
@@ -192,6 +199,77 @@ class TestMaskedSelectError(unittest.TestCase):
        self.assertRaises(TypeError, test_mask_dtype)


class TestMaskedSelectBroadcast(unittest.TestCase):
    def setUp(self):
        paddle.disable_static()

    def test_broadcast(self):
        shape = (3, 4)
        np_x = np.random.random(shape).astype('float32')
        np_mask = np.array([[True], [False], [False]])
        x = paddle.to_tensor(np_x)
        mask = paddle.to_tensor(np_mask)
        out = paddle.masked_select(x, mask)
        np_out = np_x[0]
        np.testing.assert_allclose(out.numpy(), np_out, rtol=1e-05)

    def test_broadcast_grad(self):
        shape = (3, 4)
        np_x = np.random.random(shape).astype('float32')
        np_mask = np.array([[True], [False], [False]])
        x = paddle.to_tensor(np_x, stop_gradient=False)
        mask = paddle.to_tensor(np_mask)
        out = paddle.masked_select(x, mask)
        out.sum().backward()
        np_out = np.zeros(shape)
        np_out[0] = 1.0
        np.testing.assert_allclose(x.grad.numpy(), np_out, rtol=1e-05)

    def test_broadcast_zerodim(self):
        shape = (3, 4)
        np_x = np.random.random(shape).astype('float32')
        x = paddle.to_tensor(np_x)
        mask = paddle.to_tensor(True)
        out = paddle.masked_select(x, mask)
        np_out = np_x.reshape(-1)
        np.testing.assert_allclose(out.numpy(), np_out, rtol=1e-05)

    def test_broadcast_zerodim_grad(self):
        shape = (3, 4)
        np_x = np.random.random(shape).astype('float32')
        np_mask = np.array(True)
        x = paddle.to_tensor(np_x, stop_gradient=False)
        mask = paddle.to_tensor(np_mask)
        out = paddle.masked_select(x, mask)
        out.sum().backward()
        np_out = np.ones(shape)
        np.testing.assert_allclose(x.grad.numpy(), np_out, rtol=1e-05)


class TestMaskedSelectOpBroadcast(TestMaskedSelectOp):
    def init(self):
        self.shape = (3, 40)
        self.mask_shape = (3, 1)


class TestMaskedSelectOpBroadcast2(TestMaskedSelectOp):
    def init(self):
        self.shape = (300, 1)
        self.mask_shape = (300, 40)


class TestMaskedSelectOpBroadcast3(TestMaskedSelectOp):
    def init(self):
        self.shape = (120,)
        self.mask_shape = (300, 120)


class TestMaskedSelectOpBroadcast4(TestMaskedSelectOp):
    def init(self):
        self.shape = (300, 40)
        self.mask_shape = 40


if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()