diff --git a/paddle/fluid/operators/masked_select_op_npu.cc b/paddle/fluid/operators/masked_select_op_npu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..9d7e3157fa1f5f13cdb751d9730c1f85cabe94c3
--- /dev/null
+++ b/paddle/fluid/operators/masked_select_op_npu.cc
@@ -0,0 +1,195 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/masked_select_op.h"
+#include "paddle/fluid/operators/npu_op_runner.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+class MaskedSelectedNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto input = ctx.Input<framework::Tensor>("X");
+    auto mask = ctx.Input<framework::Tensor>("Mask");
+    auto out = ctx.Output<framework::Tensor>("Y");
+
+    auto input_dim = input->dims();
+    auto mask_dim = mask->dims();
+    PADDLE_ENFORCE_EQ(
+        input_dim, mask_dim,
+        platform::errors::InvalidArgument(
+            "The dims of input and mask in OP(masked_select) "
+            "must be equal, but got input dims: (%s), mask dims: "
+            "(%s). Please check the input value.",
+            input_dim, mask_dim));
+
+    auto& dev_ctx =
+        ctx.template device_context<platform::NPUDeviceContext>();
+    auto stream = dev_ctx.stream();
+
+    Tensor mask_int32, out_size;
+    std::vector<int32_t> out_size_vec;
+    mask_int32.mutable_data<int32_t>(mask->dims(), ctx.GetPlace());
+    out_size.mutable_data<int32_t>({1}, ctx.GetPlace());
+    {
+      const auto& cast_runner =
+          NpuOpRunner("Cast", {*mask}, {mask_int32},
+                      {{"dst_type", static_cast<int>(ConvertToNpuDtype(
+                                        framework::proto::VarType::INT32))}});
+      cast_runner.Run(stream);
+
+      mask_int32.Resize({mask_int32.numel()});
+      NpuOpRunner sum_runner;
+      sum_runner.SetType("ReduceSum");
+      sum_runner.AddInput(mask_int32);
+      sum_runner.AddInput(std::vector<int32_t>({0}));
+      sum_runner.AddOutput(out_size);
+      sum_runner.AddAttr("keep_dims", false);
+      sum_runner.Run(stream);
+      TensorToVector(out_size, dev_ctx, &out_size_vec);
+    }
+
+    out->Resize({out_size_vec[0]});
+    out->mutable_data<T>(ctx.GetPlace());
+
+    Tensor topkv2_out, indices;
+    topkv2_out.mutable_data<int32_t>({out_size_vec[0]}, ctx.GetPlace());
+    indices.mutable_data<int32_t>({out_size_vec[0]}, ctx.GetPlace());
+    {
+      NpuOpRunner topkv2_runner;
+      topkv2_runner.SetType("TopKV2")
+          .AddInput(mask_int32)
+          .AddInput(out_size)
+          .AddOutput(topkv2_out)
+          .AddOutput(indices)
+          .AddAttr("sorted", false)
+          .AddAttr("dim", 0)
+          .AddAttr("largest", true)
+          .Run(stream);
+      // TopKV2's sort may be unstable, so re-sort the indices ascending
+      NpuOpRunner topkv2_runner2;
+      topkv2_runner2.SetType("TopKV2")
+          .AddInput(indices)
+          .AddInput(out_size)
+          .AddOutput(topkv2_out)
+          .AddOutput(indices)
+          .AddAttr("sorted", true)
+          .AddAttr("dim", 0)
+          .AddAttr("largest", false)
+          .Run(stream);
+
+      Tensor input_tmp;
+      input_tmp.ShareDataWith(*input);
+      input_tmp.Resize({input->numel()});
+      const auto& gather_runner = NpuOpRunner(
+          "GatherV2D", {input_tmp, topkv2_out}, {*out}, {{"axis", 0}});
+      gather_runner.Run(stream);
+    }
+  }
+};
+
+template <typename DeviceContext, typename T>
+class MaskedSelectedGradNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto mask = ctx.Input<framework::Tensor>("Mask");
+    auto y_grad = ctx.Input<framework::Tensor>(framework::GradVarName("Y"));
+    auto x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+
+    x_grad->mutable_data<T>(ctx.GetPlace());
+
+    auto& dev_ctx =
+        ctx.template device_context<platform::NPUDeviceContext>();
+    auto stream = dev_ctx.stream();
+
+    Tensor mask_int32, out_size;
+    std::vector<int32_t> out_size_vec;
+    mask_int32.mutable_data<int32_t>(mask->dims(), ctx.GetPlace());
+    out_size.mutable_data<int32_t>({1}, ctx.GetPlace());
+    {
+      const auto& cast_runner =
+          NpuOpRunner("Cast", {*mask}, {mask_int32},
+                      {{"dst_type", static_cast<int>(ConvertToNpuDtype(
+                                        framework::proto::VarType::INT32))}});
+      cast_runner.Run(stream);
+
+      mask_int32.Resize({mask_int32.numel()});
+      NpuOpRunner sum_runner;
+      sum_runner.SetType("ReduceSum");
+      sum_runner.AddInput(mask_int32);
+      sum_runner.AddInput(std::vector<int32_t>({0}));
+      sum_runner.AddOutput(out_size);
+      sum_runner.AddAttr("keep_dims", false);
+      sum_runner.Run(stream);
+      TensorToVector(out_size, dev_ctx, &out_size_vec);
+    }
+
+    Tensor topkv2_out, indices;
+    topkv2_out.mutable_data<int32_t>({out_size_vec[0]}, ctx.GetPlace());
+    indices.mutable_data<int32_t>({out_size_vec[0]}, ctx.GetPlace());
+    {
+      NpuOpRunner topkv2_runner;
+      topkv2_runner.SetType("TopKV2")
+          .AddInput(mask_int32)
+          .AddInput(out_size)
+          .AddOutput(topkv2_out)
+          .AddOutput(indices)
+          .AddAttr("sorted", false)
+          .AddAttr("dim", 0)
+          .AddAttr("largest", true)
+          .Run(stream);
+
+      NpuOpRunner topkv2_runner2;
+      topkv2_runner2.SetType("TopKV2")
+          .AddInput(indices)
+          .AddInput(out_size)
+          .AddOutput(topkv2_out)
+          .AddOutput(indices)
+          .AddAttr("sorted", true)
+          .AddAttr("dim", 0)
+          .AddAttr("largest", false)
+          .Run(stream);
+
+      topkv2_out.Resize({out_size_vec[0], 1});
+      x_grad->Resize({x_grad->numel()});
+      NpuOpRunner scatter_runner;
+      scatter_runner.SetType("ScatterNd");
+      scatter_runner.AddInput(topkv2_out);
+      scatter_runner.AddInput(*y_grad);
+      scatter_runner.AddInput(
+          std::vector<int32_t>({static_cast<int32_t>(x_grad->numel())}));
+      scatter_runner.AddOutput(*x_grad);
+      scatter_runner.Run(stream);
+      x_grad->Resize(mask->dims());
+    }
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+REGISTER_OP_NPU_KERNEL(
+    masked_select,
+    ops::MaskedSelectedNPUKernel<plat::NPUDeviceContext, plat::float16>,
+    ops::MaskedSelectedNPUKernel<plat::NPUDeviceContext, float>,
+    ops::MaskedSelectedNPUKernel<plat::NPUDeviceContext, int>,
+    ops::MaskedSelectedNPUKernel<plat::NPUDeviceContext, int64_t>);
+REGISTER_OP_NPU_KERNEL(
+    masked_select_grad,
+    ops::MaskedSelectedGradNPUKernel<plat::NPUDeviceContext, plat::float16>,
+    ops::MaskedSelectedGradNPUKernel<plat::NPUDeviceContext, float>,
+    ops::MaskedSelectedGradNPUKernel<plat::NPUDeviceContext, int>,
+    ops::MaskedSelectedGradNPUKernel<plat::NPUDeviceContext, int64_t>);
diff --git a/python/paddle/fluid/tests/unittests/npu/test_masked_select_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_masked_select_op_npu.py
new file mode 100644
index 0000000000000000000000000000000000000000..13078aea6903ae464a86f2f2927955ec79f78a28
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/npu/test_masked_select_op_npu.py
@@ -0,0 +1,166 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import unittest
+import sys
+sys.path.append("..")
+from op_test import OpTest, skip_check_grad_ci
+import paddle
+import paddle.fluid as fluid
+
+paddle.enable_static()
+
+
+def np_masked_select(x, mask):
+    result = np.empty(shape=(0), dtype=x.dtype)
+    for ele, ma in zip(np.nditer(x), np.nditer(mask)):
+        if ma:
+            result = np.append(result, ele)
+    return result.flatten()
+
+
+class TestMaskedSelectOp(OpTest):
+    def set_npu(self):
+        self.__class__.use_npu = True
+
+    def setUp(self):
+        self.set_npu()
+        self.init()
+        self.init_dtype()
+        self.place = paddle.NPUPlace(0)
+        self.op_type = "masked_select"
+        x = np.random.random(self.shape).astype(self.dtype)
+        mask = np.array(np.random.randint(2, size=self.shape, dtype=bool))
+        out = np_masked_select(x, mask)
+        self.inputs = {'X': x, 'Mask': mask}
+        self.outputs = {'Y': out}
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(self.place, ['X'], 'Y')
+
+    def init(self):
+        self.shape = (50, 3)
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+
+class TestMaskedSelectOp1(TestMaskedSelectOp):
+    def init(self):
+        self.shape = (6, 8, 9, 18)
+
+
+class TestMaskedSelectOp2(TestMaskedSelectOp):
+    def init(self):
+        self.shape = (168, )
+
+
+class TestMaskedSelectOpFp16(TestMaskedSelectOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+    def test_check_grad(self):
+        x_grad = self.inputs['Mask'].astype(self.dtype)
+        x_grad = x_grad * (1 / x_grad.sum())
+        self.check_grad_with_place(
+            self.place, ['X'], 'Y', user_defined_grads=[x_grad])
+
+
+@skip_check_grad_ci(reason="get_numeric_gradient not support int32")
+class TestMaskedSelectOpInt32(TestMaskedSelectOp):
+    def init_dtype(self):
+        self.dtype = np.int32
+
+    def test_check_grad(self):
+        pass
+
+
+@skip_check_grad_ci(reason="get_numeric_gradient not support int64")
+class TestMaskedSelectOpInt64(TestMaskedSelectOp):
+    def init_dtype(self):
+        self.dtype = np.int64
+
+    def test_check_grad(self):
+        pass
+
+
+class TestMaskedSelectAPI(unittest.TestCase):
+    def test_imperative_mode(self):
+        paddle.disable_static(paddle.NPUPlace(0))
+        shape = (88, 6, 8)
+        np_x = np.random.random(shape).astype('float32')
+        np_mask = np.array(np.random.randint(2, size=shape, dtype=bool))
+        x = paddle.to_tensor(np_x)
+        mask = paddle.to_tensor(np_mask)
+        out = paddle.masked_select(x, mask)
+        np_out = np_masked_select(np_x, np_mask)
+        self.assertEqual(np.allclose(out.numpy(), np_out), True)
+        paddle.enable_static()
+
+    def test_static_mode(self):
+        shape = [8, 9, 6]
+        x = paddle.fluid.data(shape=shape, dtype='float32', name='x')
+        mask = paddle.fluid.data(shape=shape, dtype='bool', name='mask')
+        np_x = np.random.random(shape).astype('float32')
+        np_mask = np.array(np.random.randint(2, size=shape, dtype=bool))
+
+        out = paddle.masked_select(x, mask)
+        np_out = np_masked_select(np_x, np_mask)
+
+        exe = paddle.static.Executor(place=paddle.NPUPlace(0))
+
+        res = exe.run(paddle.static.default_main_program(),
+                      feed={"x": np_x,
+                            "mask": np_mask},
+                      fetch_list=[out])
+        self.assertEqual(np.allclose(res, np_out), True)
+
+
+class TestMaskedSelectError(unittest.TestCase):
+    def test_error(self):
+        with paddle.static.program_guard(paddle.static.Program(),
+                                         paddle.static.Program()):
+
+            shape = [8, 9, 6]
+            x = paddle.fluid.data(shape=shape, dtype='float32', name='x')
+            mask = paddle.fluid.data(shape=shape, dtype='bool', name='mask')
+            mask_float = paddle.fluid.data(
+                shape=shape, dtype='float32', name='mask_float')
+            np_x = np.random.random(shape).astype('float32')
+            np_mask = np.array(np.random.randint(2, size=shape, dtype=bool))
+
+            def test_x_type():
+                paddle.masked_select(np_x, mask)
+
+            self.assertRaises(TypeError, test_x_type)
+
+            def test_mask_type():
+                paddle.masked_select(x, np_mask)
+
+            self.assertRaises(TypeError, test_mask_type)
+
+            def test_mask_dtype():
+                paddle.masked_select(x, mask_float)
+
+            self.assertRaises(TypeError, test_mask_dtype)
+
+
+if __name__ == '__main__':
+    unittest.main()
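
Review note (not part of the patch): the NPU has no direct masked-select primitive, so the forward kernel composes one from Cast, ReduceSum, two TopKV2 runs, and GatherV2D. A minimal NumPy sketch of that composition, with the hypothetical helper name `masked_select_via_topk`, for readers following the kernel:

```python
import numpy as np


def masked_select_via_topk(x, mask):
    # Cast: bool mask -> int32, flattened, as the kernel does.
    mask_int32 = mask.astype(np.int32).reshape(-1)
    # ReduceSum: number of selected elements (out_size).
    k = int(mask_int32.sum())
    if k == 0:
        return np.empty((0, ), dtype=x.dtype)
    # First TopKV2 (largest=True, sorted=False): indices of the k ones,
    # in no guaranteed order.
    indices = np.argpartition(mask_int32, -k)[-k:]
    # Second TopKV2 (largest=False, sorted=True): the first pass may be
    # unstable, so sort the indices ascending to preserve input order.
    indices = np.sort(indices)
    # GatherV2D along axis 0 on the flattened input.
    return x.reshape(-1)[indices]
```

This matches the `np_masked_select` reference in the test file: the output is the flattened input's true-mask entries, kept in their original order.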
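Likewise for the backward kernel: after recovering the same ascending indices, ScatterNd writes `y_grad` into a zero-initialized tensor of the input's flattened shape. A hedged NumPy equivalent, again with a hypothetical helper name:

```python
import numpy as np


def masked_select_grad_via_scatter(y_grad, mask):
    # Same index extraction as the forward pass (double TopKV2).
    indices = np.sort(np.flatnonzero(mask.reshape(-1)))
    # ScatterNd: route y_grad to the masked positions of a zero tensor,
    # then restore the mask's original dims.
    x_grad = np.zeros(mask.size, dtype=y_grad.dtype)
    x_grad[indices] = y_grad
    return x_grad.reshape(mask.shape)
```

Gradients flow only to positions where the mask is true; everything else stays zero, which is what `TestMaskedSelectOpFp16.test_check_grad` encodes by supplying `mask / mask.sum()` (the gradient of the mean of Y) as the user-defined gradient.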