diff --git a/paddle/fluid/operators/masked_select_op_mlu.cc b/paddle/fluid/operators/masked_select_op_mlu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..279096b762ca8738cbc9352dc7341ab0725cf049
--- /dev/null
+++ b/paddle/fluid/operators/masked_select_op_mlu.cc
@@ -0,0 +1,204 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/mlu/mlu_baseop.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class MaskedSelectedMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto input = ctx.Input<Tensor>("X");
+    auto mask = ctx.Input<Tensor>("Mask");
+    auto out = ctx.Output<Tensor>("Y");
+
+    auto input_dim = input->dims();
+    auto mask_dim = mask->dims();
+    PADDLE_ENFORCE_EQ(
+        input_dim,
+        mask_dim,
+        platform::errors::InvalidArgument(
+            "The dim size of input and mask in OP(masked_select) "
+            "must be equal, but got input dim: (%ld), mask dim: "
+            "(%ld). Please check input "
+            "value.",
+            input_dim,
+            mask_dim));
+
+    Tensor number(framework::TransToPhiDataType(VT::INT32));
+    void* number_ptr = number.mutable_data<int32_t>({1}, ctx.GetPlace());
+
+    // The output keeps the mask's shape; CNNL packs the selected values at
+    // the front and reports how many were selected through `number`.
+    out->Resize(mask->dims());
+    out->mutable_data<T>(ctx.GetPlace());
+
+    MLUCnnlTensorDesc input_desc(*input);
+    MLUCnnlTensorDesc mask_desc(*mask);
+    MLUCnnlTensorDesc out_desc(*out);
+    MLUCnnl::Mask(ctx,
+                  CNNL_MASKED_SELECT,
+                  input_desc.get(),
+                  GetBasePtr(input),
+                  mask_desc.get(),
+                  GetBasePtr(mask),
+                  nullptr /*value_desc*/,
+                  nullptr /*value*/,
+                  out_desc.get(),
+                  GetBasePtr(out),
+                  static_cast<uint32_t*>(number_ptr));
+  }
+};
+
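+// The backward pass is assembled from several CNNL ops:
+//   1. cast the bool mask to int32;
+//   2. reduce_sum the casted mask to count the selected elements;
+//   3. TopK over the flattened mask, so the flat indices of all selected
+//      positions are gathered at the front of the indices output;
+//   4. copy the first `count` indices and the first `count` y_grad values
+//      (the forward kernel packs the selected values at the front of Y);
+//   5. ScatterNd the copied values back into a flattened x_grad, then
+//      restore x_grad to the mask's shape.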
+template <typename T>
+class MaskedSelectedGradMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto mask = ctx.Input<Tensor>("Mask");
+    auto y_grad = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    auto x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+
+    auto& dev_ctx =
+        ctx.template device_context<platform::MLUDeviceContext>();
+    Tensor mask_int32, out_size;
+    std::vector<int32_t> out_size_vec;
+    mask_int32.mutable_data<int32_t>(mask->dims(), ctx.GetPlace());
+    out_size.mutable_data<int32_t>({1}, ctx.GetPlace());
+
+    // Step 1: cast the bool mask to int32 so it can be reduced.
+    MLUCnnlTensorDesc mask_desc(*mask);
+    MLUCnnlTensorDesc mask_int32_desc(mask_int32);
+    MLUCnnlTensorDesc out_size_desc(out_size);
+    auto cast_type = GetCastDataType(mask->dtype(), DataType::INT32);
+    MLUCnnl::Cast(ctx,
+                  cast_type,
+                  mask_desc.get(),
+                  GetBasePtr(mask),
+                  mask_int32_desc.get(),
+                  GetBasePtr(&mask_int32));
+
+    auto mask_int32_dim = phi::vectorize(mask_int32.dims());
+    std::vector<int> reduce_dims;
+    for (size_t i = 0; i < mask_int32_dim.size(); i++) {
+      reduce_dims.push_back(static_cast<int>(i));
+    }
+
+    // Step 2: reduce_sum over all dims to count the selected elements.
+    std::string reduce_name = "reduce_sum";
+    cnnlReduceOp_t reduce_op = GetMLUCnnlReduceOp(reduce_name);
+    MLUCnnlReduceDesc reduce_desc(reduce_dims,
+                                  reduce_op,
+                                  ToCnnlDataType<int32_t>(),
+                                  CNNL_NOT_PROPAGATE_NAN,
+                                  CNNL_REDUCE_NO_INDICES,
+                                  CNNL_32BIT_INDICES);
+
+    MLUCnnl::Reduce(ctx,
+                    true /*need_workspace*/,
+                    reduce_desc.get(),
+                    nullptr,
+                    mask_int32_desc.get(),
+                    GetBasePtr(&mask_int32),
+                    0 /*indices_size*/,
+                    nullptr,
+                    nullptr,
+                    out_size_desc.get(),
+                    GetBasePtr(&out_size));
+
+    paddle::framework::TensorToVector(out_size, dev_ctx, &out_size_vec);
+    dev_ctx.Wait();
+
+    // Step 3: TopK over the flattened mask; the flat indices of all
+    // selected positions end up at the front of indices_int32.
+    Tensor mask_int32_tmp;
+    mask_int32_tmp.ShareDataWith(mask_int32);
+    mask_int32_tmp.Resize({mask_int32.numel()});
+    Tensor topk_v2_out(framework::TransToPhiDataType(VT::INT32)),
+        indices_int32(framework::TransToPhiDataType(VT::INT32));
+    topk_v2_out.mutable_data<int32_t>({mask_int32.numel()}, ctx.GetPlace());
+    indices_int32.mutable_data<int32_t>({mask_int32.numel()}, ctx.GetPlace());
+
+    MLUCnnlTensorDesc topk_v2_out_desc(topk_v2_out);
+    MLUCnnlTensorDesc indices_int32_desc(indices_int32);
+    MLUCnnlTensorDesc mask_int32_tmp_desc(mask_int32_tmp);
+
+    const int dim = 0;
+    MLUCnnl::TopK(ctx,
+                  mask_int32.numel(),
+                  dim,
+                  true /*largest*/,
+                  false /*sorted*/,
+                  mask_int32_tmp_desc.get(),
+                  GetBasePtr(&mask_int32_tmp),
+                  topk_v2_out_desc.get(),
+                  GetBasePtr(&topk_v2_out),
+                  indices_int32_desc.get(),
+                  GetBasePtr(&indices_int32));
+
+    auto stream =
+        ctx.template device_context<platform::MLUDeviceContext>().stream();
+
+    // Step 4: keep only the first out_size indices and the matching
+    // y_grad values.
+    Tensor indices_int32_out;
+    indices_int32_out.mutable_data<int32_t>({out_size_vec[0]},
+                                            ctx.GetPlace());
+    memory::Copy(ctx.GetPlace(),
+                 GetBasePtr(&indices_int32_out),
+                 ctx.GetPlace(),
+                 GetBasePtr(&indices_int32),
+                 out_size_vec[0] * sizeof(int32_t),
+                 stream);
+
+    Tensor y_grad_tmp_out;
+    y_grad_tmp_out.mutable_data<T>({out_size_vec[0]}, ctx.GetPlace());
+    MLUCnnlTensorDesc y_grad_tmp_out_desc(y_grad_tmp_out);
+    memory::Copy(ctx.GetPlace(),
+                 GetBasePtr(&y_grad_tmp_out),
+                 ctx.GetPlace(),
+                 GetBasePtr(y_grad),
+                 out_size_vec[0] * sizeof(T),
+                 stream);
+
+    // Step 5: scatter the kept gradients back to their flat positions in
+    // x_grad, then restore x_grad to the mask's shape.
+    Tensor indices_int32_tmp;
+    indices_int32_tmp.ShareDataWith(indices_int32_out);
+    indices_int32_tmp.Resize({out_size_vec[0], 1});
+    MLUCnnlTensorDesc indices_int32_tmp_desc(indices_int32_tmp);
+
+    const cnnlScatterNdMode_t mode = CNNL_SCATTERND_UPDATE;
+    x_grad->Resize({x_grad->numel()});
+    x_grad->mutable_data<T>(ctx.GetPlace());
+    MLUCnnlTensorDesc x_grad_desc(*x_grad);
+    MLUCnnl::ScatterNd(ctx,
+                       mode,
+                       indices_int32_tmp_desc.get(),
+                       GetBasePtr(&indices_int32_tmp),
+                       y_grad_tmp_out_desc.get(),
+                       GetBasePtr(&y_grad_tmp_out),
+                       nullptr,
+                       nullptr,
+                       x_grad_desc.get(),
+                       GetBasePtr(x_grad));
+    x_grad->Resize(mask->dims());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_MLU_KERNEL(masked_select,
+                       ops::MaskedSelectedMLUKernel<float>,
+                       ops::MaskedSelectedMLUKernel<plat::float16>,
+                       ops::MaskedSelectedMLUKernel<int>);
+
+REGISTER_OP_MLU_KERNEL(masked_select_grad,
+                       ops::MaskedSelectedGradMLUKernel<float>,
+                       ops::MaskedSelectedGradMLUKernel<plat::float16>,
+                       ops::MaskedSelectedGradMLUKernel<int>);
diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc
index c0619145ad5ab336de5541d3cbe74226d9661b86..972bdefdf02b84d5ae7072cd056cf1a05e059a60 100644
--- a/paddle/fluid/operators/mlu/mlu_baseop.cc
+++ b/paddle/fluid/operators/mlu/mlu_baseop.cc
@@ -2597,6 +2597,19 @@ MLURNNDesc::~MLURNNDesc() {
       cnnlSign(handle, input_desc, input, output_desc, output));
 }
 
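+// Thin wrapper over cnnlIndexSelect: gathers slices of `input` along
+// dimension `dim` at the positions given by `index` into `output`.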
+/* static */ void MLUCnnl::IndexSelect(const ExecutionContext& ctx,
+                                       const int dim,
+                                       cnnlTensorDescriptor_t input_desc,
+                                       const void* input,
+                                       const cnnlTensorDescriptor_t index_desc,
+                                       const void* index,
+                                       const cnnlTensorDescriptor_t output_desc,
+                                       void* output) {
+  cnnlHandle_t handle = GetHandleFromCTX(ctx);
+  PADDLE_ENFORCE_MLU_SUCCESS(cnnlIndexSelect(
+      handle, dim, input_desc, input, index_desc, index, output_desc, output));
+}
+
 /* static */ void MLUCnnl::IsFinite(const ExecutionContext& ctx,
                                     const cnnlTensorDescriptor_t input_desc,
                                     const void* input,
diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h
index 9031040ec55984ddd7e76ce7b2d8a2e2dd6c15f4..85f4439c3b9743ebc71e72d5d4be1cbd32175a4f 100644
--- a/paddle/fluid/operators/mlu/mlu_baseop.h
+++ b/paddle/fluid/operators/mlu/mlu_baseop.h
@@ -1391,6 +1391,15 @@ class MLUCnnl {
                      const cnnlTensorDescriptor_t output_desc,
                      void* output);
 
+  static void IndexSelect(const ExecutionContext& ctx,
+                          const int dim,
+                          cnnlTensorDescriptor_t input_desc,
+                          const void* input,
+                          const cnnlTensorDescriptor_t index_desc,
+                          const void* index,
+                          const cnnlTensorDescriptor_t output_desc,
+                          void* output);
+
   static void IsFinite(const ExecutionContext& ctx,
                        const cnnlTensorDescriptor_t input_desc,
                        const void* input,
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_masked_select_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_masked_select_op_mlu.py
new file mode 100644
index 0000000000000000000000000000000000000000..7efed0ea4b0f870bbf12b009dd49d7a1c5e0ba8b
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/test_masked_select_op_mlu.py
@@ -0,0 +1,169 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import sys
+
+sys.path.append("..")
+import numpy as np
+from op_test import OpTest, skip_check_grad_ci
+import paddle.fluid as fluid
+import paddle
+
+paddle.enable_static()
+
+
+def np_masked_select(shape, x, mask):
+    # Reference implementation: collect the selected elements first, then
+    # pad with zeros so the result keeps the mask's shape, matching the
+    # output layout of the MLU kernel.
+    result = np.empty(shape=(0), dtype=x.dtype)
+    for ele, ma in zip(np.nditer(x), np.nditer(mask)):
+        if ma:
+            result = np.append(result, ele)
+    result = np.append(result,
+                       np.zeros(x.size - result.size, dtype=x.dtype))
+    return np.reshape(result, shape)
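+
+
+# For example, with x = [[1, 2], [3, 4]] and mask = [[True, False],
+# [False, True]], the selected elements are [1, 4], so
+# np_masked_select((2, 2), x, mask) returns [[1, 4], [0, 0]].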
+
+
+class TestMaskedSelectOp(OpTest):
+
+    def setUp(self):
+        self.init()
+        self.init_dtype()
+        self.__class__.use_mlu = True
+        self.place = paddle.device.MLUPlace(0)
+        self.op_type = "masked_select"
+        self.python_api = paddle.masked_select
+        x = np.random.random(self.shape).astype(self.dtype)
+        mask = np.array(np.random.randint(2, size=self.shape, dtype=bool))
+        out = np_masked_select(self.shape, x, mask)
+        self.inputs = {'X': x, 'Mask': mask}
+        self.outputs = {'Y': out}
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(self.place, ['X'], 'Y')
+
+    def init(self):
+        self.shape = (50, 3)
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+
+class TestMaskedSelectOp1(TestMaskedSelectOp):
+
+    def init(self):
+        self.shape = (6, 8, 9, 18)
+
+
+class TestMaskedSelectOp2(TestMaskedSelectOp):
+
+    def init(self):
+        self.shape = (168, )
+
+
+@skip_check_grad_ci(reason="get_numeric_gradient not support int32")
+class TestMaskedSelectOpInt32(TestMaskedSelectOp):
+
+    def init_dtype(self):
+        self.dtype = np.int32
+
+    def test_check_grad(self):
+        pass
+
+
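+# check_grad verifies the gradient of the mean of Y, so every element of
+# dY equals 1 / Y.size; masked_select routes gradients only to selected
+# positions, giving the analytic gradient mask * (1 / mask.size) below.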
+class TestMaskedSelectOpFp16(TestMaskedSelectOp):
+
+    def init_dtype(self):
+        self.dtype = np.float16
+
+    def test_check_grad(self):
+        x_grad = self.inputs['Mask'].astype(self.dtype)
+        x_grad = x_grad * (1 / x_grad.size)
+        self.check_grad_with_place(self.place, ['X'],
+                                   'Y',
+                                   user_defined_grads=[x_grad])
+
+
+class TestMaskedSelectAPI(unittest.TestCase):
+
+    def test_imperative_mode(self):
+        paddle.disable_static()
+        shape = (88, 6, 8)
+        np_x = np.random.random(shape).astype('float32')
+        np_mask = np.array(np.random.randint(2, size=shape, dtype=bool))
+        x = paddle.to_tensor(np_x)
+        mask = paddle.to_tensor(np_mask)
+        out = paddle.masked_select(x, mask)
+        np_out = np_masked_select(shape, np_x, np_mask)
+        self.assertEqual(np.allclose(out.numpy(), np_out), True)
+        paddle.enable_static()
+
+    def test_static_mode(self):
+        shape = [8, 9, 6]
+        x = paddle.fluid.data(shape=shape, dtype='float32', name='x')
+        mask = paddle.fluid.data(shape=shape, dtype='bool', name='mask')
+        np_x = np.random.random(shape).astype('float32')
+        np_mask = np.array(np.random.randint(2, size=shape, dtype=bool))
+
+        out = paddle.masked_select(x, mask)
+        np_out = np_masked_select(shape, np_x, np_mask)
+
+        exe = paddle.static.Executor(place=paddle.device.MLUPlace(0))
+
+        res = exe.run(paddle.static.default_main_program(),
+                      feed={
+                          "x": np_x,
+                          "mask": np_mask
+                      },
+                      fetch_list=[out])
+        self.assertEqual(np.allclose(res, np_out), True)
+
+
+class TestMaskedSelectError(unittest.TestCase):
+
+    def test_error(self):
+        with paddle.static.program_guard(paddle.static.Program(),
+                                         paddle.static.Program()):
+
+            shape = [8, 9, 6]
+            x = paddle.fluid.data(shape=shape, dtype='float32', name='x')
+            mask = paddle.fluid.data(shape=shape, dtype='bool', name='mask')
+            mask_float = paddle.fluid.data(shape=shape,
+                                           dtype='float32',
+                                           name='mask_float')
+            np_x = np.random.random(shape).astype('float32')
+            np_mask = np.array(np.random.randint(2, size=shape, dtype=bool))
+
+            def test_x_type():
+                paddle.masked_select(np_x, mask)
+
+            self.assertRaises(TypeError, test_x_type)
+
+            def test_mask_type():
+                paddle.masked_select(x, np_mask)
+
+            self.assertRaises(TypeError, test_mask_type)
+
+            def test_mask_dtype():
+                paddle.masked_select(x, mask_float)
+
+            self.assertRaises(TypeError, test_mask_dtype)
+
+
+if __name__ == '__main__':
+    unittest.main()