diff --git a/paddle/fluid/operators/where_index_op_npu.cc b/paddle/fluid/operators/where_index_op_npu.cc new file mode 100644 index 0000000000000000000000000000000000000000..da252094df96a2e0a0c708dad1572d74ecfcd702 --- /dev/null +++ b/paddle/fluid/operators/where_index_op_npu.cc @@ -0,0 +1,97 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/where_index_op.h" +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class NPUWhereIndexKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto& dev_ctx = + context.template device_context(); + auto* condition = context.Input("Condition"); + auto* out = context.Output("Out"); + + auto dims = condition->dims(); + const int rank = dims.size(); + + auto place = context.GetPlace(); + const aclrtStream& stream = dev_ctx.stream(); + + // Run Cast and ReduceSum to get 0 dim of Out + Tensor booled_cond; + if (condition->type() != framework::proto::VarType::BOOL) { + auto bool_type = ConvertToNpuDtype(framework::proto::VarType::BOOL); + booled_cond.mutable_data(dims, place); + const auto& booled_runner = + NpuOpRunner("Cast", {*condition}, {booled_cond}, + {{"dst_type", static_cast(bool_type)}}); + booled_runner.Run(stream); + } else { + booled_cond.ShareDataWith(*condition); + } + Tensor casted_cond; + auto dst_dtype = ConvertToNpuDtype(framework::proto::VarType::INT64); + casted_cond.mutable_data(dims, place); + const auto& cast_runner = + NpuOpRunner("Cast", {booled_cond}, {casted_cond}, + {{"dst_type", static_cast(dst_dtype)}}); + cast_runner.Run(stream); + + Tensor sumed_true_num; + sumed_true_num.mutable_data({1}, place); + Tensor cond_axes; + cond_axes.mutable_data({dims.size()}, place); + std::vector axes_vec; + for (int i = 0; i < dims.size(); ++i) { + axes_vec.push_back(i); + } + framework::TensorFromVector(axes_vec, dev_ctx, &cond_axes); + const auto& sum_runner = + NpuOpRunner("ReduceSum", {casted_cond, cond_axes}, {sumed_true_num}, + {{"keep_dims", false}}); + sum_runner.Run(stream); + + Tensor local_true_num; + TensorCopySync(sumed_true_num, platform::CPUPlace(), &local_true_num); + auto true_num = *local_true_num.data(); + + out->Resize(framework::make_ddim({true_num, rank})); + out->mutable_data(place); + + if (true_num == 0) { + return; + } + + out->set_layout(DataLayout::kAnyLayout); + NpuOpRunner runner{"Where", {*condition}, {*out}}; + runner.Run(stream); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_NPU_KERNEL(where_index, ops::NPUWhereIndexKernel, + ops::NPUWhereIndexKernel, + ops::NPUWhereIndexKernel, + ops::NPUWhereIndexKernel, + ops::NPUWhereIndexKernel); diff --git a/python/paddle/fluid/tests/unittests/npu/test_where_index_npu.py b/python/paddle/fluid/tests/unittests/npu/test_where_index_npu.py new file mode 100644 index 0000000000000000000000000000000000000000..20d7fb6879d443fd87de0a64102b1c9214c5e2ae --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_where_index_npu.py @@ -0,0 +1,106 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import paddle +import sys +sys.path.append("..") +from op_test import OpTest +from paddle.fluid.op import Operator +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard + +paddle.enable_static() + + +class TestWhereIndexOp(OpTest): + def setUp(self): + self.set_npu() + self.op_type = "where_index" + self.place = paddle.NPUPlace(0) + self.init_config() + + def test_check_output(self): + self.check_output_with_place(self.place) + + def init_config(self): + self.inputs = {'Condition': np.array([True, False, True]), } + + self.outputs = {'Out': np.array([[0], [2]], dtype='int64')} + + def set_npu(self): + self.__class__.use_npu = True + + +class TestNotBool(TestWhereIndexOp): + def init_config(self): + self.inputs = {'Condition': np.array([1, 0, 8]), } + + self.outputs = {'Out': np.array([[0], [2]], dtype='int64')} + + +class TestAllFalse(TestWhereIndexOp): + def init_config(self): + self.inputs = {'Condition': np.array([False, False, False]), } + + self.outputs = {'Out': np.array([], dtype='int64')} + + +class TestRank2(TestWhereIndexOp): + def init_config(self): + self.inputs = {'Condition': np.array([[True, False], [False, True]]), } + + self.outputs = {'Out': np.array([[0, 0], [1, 1]], dtype='int64')} + + +class TestRank3(TestWhereIndexOp): + def init_config(self): + self.inputs = { + 'Condition': np.array([[[True, False], [False, True]], + [[False, True], [True, False]], + [[False, False], [False, True]]]), + } + + self.outputs = { + 'Out': np.array( + [[0, 0, 0], [0, 1, 1], [1, 0, 1], [1, 1, 0], [2, 1, 1]], + dtype='int64') + } + + +class TestWhereOpError(unittest.TestCase): + def test_api(self): + with program_guard(Program(), Program()): + cond = fluid.layers.data(name='cond', shape=[4], dtype='bool') + result = fluid.layers.where(cond) + + exe = fluid.Executor(paddle.NPUPlace(0)) + exe.run(fluid.default_startup_program()) + cond_i = np.array([True, False, False, False]).astype("bool") + out = exe.run(fluid.default_main_program(), feed={'cond': cond_i}) + + +class TestWhereRaiseError(unittest.TestCase): + def test_errors(self): + def test_type(): + fluid.layers.where([10]) + + self.assertRaises(TypeError, test_type) + + +if __name__ == "__main__": + unittest.main()