From e02dec01e7364beeb9845cfc42ca00da58b7d159 Mon Sep 17 00:00:00 2001
From: fwenguang <95677191+fwenguang@users.noreply.github.com>
Date: Thu, 20 Jan 2022 20:05:54 +0800
Subject: [PATCH] [MLU]add mlu kernel for top_k and top_k_v2 (#39065)

---
 paddle/fluid/operators/top_k_op_mlu.cc        |  77 +++++
 paddle/fluid/operators/top_k_v2_op_mlu.cc     |  85 ++++++
 .../tests/unittests/mlu/test_top_k_op_mlu.py  |  73 +++++
 .../unittests/mlu/test_top_k_v2_op_mlu.py     | 285 ++++++++++++++++++
 4 files changed, 520 insertions(+)
 create mode 100644 paddle/fluid/operators/top_k_op_mlu.cc
 create mode 100644 paddle/fluid/operators/top_k_v2_op_mlu.cc
 create mode 100644 python/paddle/fluid/tests/unittests/mlu/test_top_k_op_mlu.py
 create mode 100644 python/paddle/fluid/tests/unittests/mlu/test_top_k_v2_op_mlu.py

diff --git a/paddle/fluid/operators/top_k_op_mlu.cc b/paddle/fluid/operators/top_k_op_mlu.cc
new file mode 100644
index 0000000000..affe5a4bc6
--- /dev/null
+++ b/paddle/fluid/operators/top_k_op_mlu.cc
@@ -0,0 +1,77 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/top_k_op.h"
+#include "paddle/fluid/operators/mlu/mlu_baseop.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class TopkMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<framework::Tensor>("X");
+    auto* output = ctx.Output<framework::Tensor>("Out");
+    auto* indices = ctx.Output<framework::Tensor>("Indices");
+    const auto& place = ctx.GetPlace();
+
+    size_t k = static_cast<size_t>(ctx.Attr<int>("k"));
+    auto* k_t = ctx.Input<framework::Tensor>("K");
+    if (k_t) {
+      auto k_t_ptr = static_cast<const int*>(k_t->data<int>());
+      auto size = k_t->numel() * sizeof(int);
+      memory::Copy(platform::CPUPlace(), reinterpret_cast<void*>(&k),
+                   BOOST_GET_CONST(platform::MLUPlace, k_t->place()), k_t_ptr,
+                   size, nullptr);
+      framework::DDim output_dims = output->dims();
+      output_dims[output_dims.size() - 1] = k;
+      output->Resize(output_dims);
+      indices->Resize(output_dims);
+    }
+
+    output->mutable_data<T>(place);
+    indices->mutable_data<int64_t>(place);
+
+    const bool largest = true;
+    const bool sorted = true;
+    const int axis = -1;
+    // cnnl only supports int32/int16 type of indices
+    framework::Tensor indices_int32(VT::INT32);
+    indices_int32.Resize(indices->dims());
+    indices_int32.mutable_data<int32_t>(place);
+
+    MLUCnnlTensorDesc input_desc(*input);
+    MLUCnnlTensorDesc values_output_desc(*output);
+    MLUCnnlTensorDesc indices_int32_desc(indices_int32);
+    MLUCnnl::TopK(ctx, k, axis, largest, sorted, input_desc.get(),
+                  GetBasePtr(input), values_output_desc.get(),
+                  GetBasePtr(output), indices_int32_desc.get(),
+                  GetBasePtr(&indices_int32));
+
+    // cast indices type to int64
+    MLUCnnlTensorDesc cast_output_desc(*indices);
+    cnnlCastDataType_t cast_type = GetCastDataType(VT::INT32, VT::INT64);
+    MLUCnnl::Cast(ctx, cast_type, indices_int32_desc.get(),
+                  GetBasePtr(&indices_int32), cast_output_desc.get(),
+                  GetBasePtr(indices));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_MLU_KERNEL(top_k, ops::TopkMLUKernel<float>,
+                       ops::TopkMLUKernel<paddle::platform::float16>);
diff --git a/paddle/fluid/operators/top_k_v2_op_mlu.cc b/paddle/fluid/operators/top_k_v2_op_mlu.cc
new file mode 100644
index 0000000000..08c960186b
--- /dev/null
+++ b/paddle/fluid/operators/top_k_v2_op_mlu.cc
@@ -0,0 +1,85 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/top_k_v2_op.h"
+#include "paddle/fluid/operators/mlu/mlu_baseop.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class TopkV2MLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<framework::Tensor>("X");
+    auto* output = ctx.Output<framework::Tensor>("Out");
+    auto* indices = ctx.Output<framework::Tensor>("Indices");
+    const auto& place = ctx.GetPlace();
+
+    const auto& sorted = static_cast<bool>(ctx.Attr<bool>("sorted"));
+    const auto& largest = static_cast<bool>(ctx.Attr<bool>("largest"));
+
+    // axis < 0, calculate the real axis
+    int axis = static_cast<int>(ctx.Attr<int>("axis"));
+    if (axis < 0) {
+      const auto& in_dims = input->dims();
+      axis += in_dims.size();
+    }
+
+    size_t k = static_cast<size_t>(ctx.Attr<int>("k"));
+    auto* k_t = ctx.Input<framework::Tensor>("K");
+    if (k_t) {
+      auto k_t_ptr = static_cast<const int*>(k_t->data<int>());
+      auto size = k_t->numel() * sizeof(int);
+      memory::Copy(platform::CPUPlace(), reinterpret_cast<void*>(&k),
+                   BOOST_GET_CONST(platform::MLUPlace, k_t->place()), k_t_ptr,
+                   size, nullptr);
+      framework::DDim output_dims = output->dims();
+      // according to axis, set the K value in the dims
+      output_dims[axis] = k;
+      output->Resize(output_dims);
+      indices->Resize(output_dims);
+    }
+
+    output->mutable_data<T>(place);
+    indices->mutable_data<int64_t>(place);
+
+    // cnnl only supports int32/int16 type of indices
+    framework::Tensor indices_int32(VT::INT32);
+    indices_int32.Resize(indices->dims());
+    indices_int32.mutable_data<int32_t>(place);
+
+    MLUCnnlTensorDesc input_desc(*input);
+    MLUCnnlTensorDesc values_output_desc(*output);
+    MLUCnnlTensorDesc indices_int32_desc(indices_int32);
+    MLUCnnl::TopK(ctx, k, axis, largest, sorted, input_desc.get(),
+                  GetBasePtr(input), values_output_desc.get(),
+                  GetBasePtr(output), indices_int32_desc.get(),
+                  GetBasePtr(&indices_int32));
+
+    // cast indices type to int64
+    MLUCnnlTensorDesc cast_output_desc(*indices);
+    cnnlCastDataType_t cast_type = GetCastDataType(VT::INT32, VT::INT64);
+    MLUCnnl::Cast(ctx, cast_type, indices_int32_desc.get(),
+                  GetBasePtr(&indices_int32), cast_output_desc.get(),
+                  GetBasePtr(indices));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_MLU_KERNEL(top_k_v2, ops::TopkV2MLUKernel<float>,
+                       ops::TopkV2MLUKernel<paddle::platform::float16>);
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_top_k_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_top_k_op_mlu.py
new file mode 100644
index 0000000000..8ad0e787ab
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/test_top_k_op_mlu.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import sys
+sys.path.append('..')
+from op_test import OpTest
+import paddle
+import paddle.fluid.core as core
+
+
+class TestTopkOp(OpTest):
+    def setUp(self):
+        self.variable_k = False
+        self.set_args()
+        self.op_type = "top_k"
+        self.place = paddle.device.MLUPlace(0)
+        self.__class__.use_mlu = True
+        self.__class__.no_need_check_grad = True
+        self.dtype = np.float32
+        self.init_dtype()
+
+        k = self.top_k
+        input = np.random.random((self.row, k)).astype(self.dtype)
+        output = np.ndarray((self.row, k))
+        indices = np.ndarray((self.row, k)).astype("int64")
+        self.inputs = {'X': input}
+
+        if self.variable_k:
+            self.inputs['K'] = np.array([k]).astype("int32")
+        else:
+            self.attrs = {'k': k}
+
+        for rowid in range(self.row):
+            row = input[rowid]
+            output[rowid] = np.sort(row)[::-1][:k]
+            indices[rowid] = row.argsort()[::-1][:k]
+
+        self.outputs = {'Out': output, 'Indices': indices}
+
+    def init_dtype(self):
+        pass
+
+    def set_args(self):
+        self.row = 100
+        self.top_k = 1
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+
+class TestTopkFP16Op(TestTopkOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+if __name__ == "__main__":
+    paddle.enable_static()
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_top_k_v2_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_top_k_v2_op_mlu.py
new file mode 100644
index 0000000000..8979344bd4
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/test_top_k_v2_op_mlu.py
@@ -0,0 +1,285 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import sys
+sys.path.append('..')
+from op_test import OpTest
+import paddle
+import paddle.fluid.core as core
+
+
+def numpy_topk(x, k=1, axis=-1, largest=True):
+    if axis < 0:
+        axis = len(x.shape) + axis
+    if largest:
+        indices = np.argsort(-x, axis=axis)
+    else:
+        indices = np.argsort(x, axis=axis)
+    if largest:
+        value = -np.sort(-x, axis=axis)
+    else:
+        value = np.sort(x, axis=axis)
+    indices = indices.take(indices=range(0, k), axis=axis)
+    value = value.take(indices=range(0, k), axis=axis)
+    return value, indices
+
+
+class TestTopkOp(OpTest):
+    def init_args(self):
+        self.k = 3
+        self.axis = 1
+        self.largest = True
+
+    def setUp(self):
+        self.op_type = "top_k_v2"
+        self.place = paddle.device.MLUPlace(0)
+        self.__class__.use_mlu = True
+        self.__class__.no_need_check_grad = True
+        self.dtype = np.float32
+        self.input_data = np.random.rand(10, 20).astype(self.dtype)
+        self.init_args()
+        self.inputs = {'X': self.input_data}
+        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
+        output, indices = numpy_topk(
+            self.input_data, axis=self.axis, k=self.k, largest=self.largest)
+        self.outputs = {'Out': output, 'Indices': indices}
+
+    def test_check_output(self):
+        paddle.enable_static()
+        self.check_output_with_place(self.place)
+
+
+class TestTopkOp1(TestTopkOp):
+    def init_args(self):
+        self.k = 3
+        self.axis = 0
+        self.largest = False
+
+
+class TestTopkOp2(TestTopkOp):
+    def init_args(self):
+        self.k = 4
+        self.axis = 0
+        self.largest = False
+
+
+class TestTopkOp3(OpTest):
+    def init_args(self):
+        self.k = 6
+        self.axis = 1
+        self.largest = True
+
+    def setUp(self):
+        self.op_type = "top_k_v2"
+        self.dtype = np.float32
+        self.input_data = np.random.rand(16, 100).astype(self.dtype)
+        self.init_args()
+        self.inputs = {'X': self.input_data}
+        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
+        output, indices = numpy_topk(
+            self.input_data, axis=self.axis, k=self.k, largest=self.largest)
+        self.outputs = {'Out': output, 'Indices': indices}
+
+
+class TestTopkOp4(TestTopkOp):
+    def init_args(self):
+        self.k = 3
+        self.axis = 1
+        self.largest = True
+
+    def setUp(self):
+        self.op_type = "top_k_v2"
+        self.place = paddle.device.MLUPlace(0)
+        self.__class__.use_mlu = True
+        self.dtype = np.float32
+        self.input_data = np.random.rand(10, 10, 5).astype(self.dtype)
+        self.init_args()
+        self.inputs = {'X': self.input_data}
+        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
+        output, indices = numpy_topk(
+            self.input_data, axis=self.axis, k=self.k, largest=self.largest)
+        self.outputs = {'Out': output, 'Indices': indices}
+
+
+class TestTopkOp5(TestTopkOp):
+    def init_args(self):
+        self.k = 3
+        self.axis = 1
+        self.largest = True
+
+    def setUp(self):
+        self.op_type = "top_k_v2"
+        self.place = paddle.device.MLUPlace(0)
+        self.__class__.use_mlu = True
+        self.dtype = np.float32
+        self.input_data = np.random.rand(10, 10, 5).astype(self.dtype)
+        self.init_args()
+        self.inputs = {'X': self.input_data}
+        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
+        output, indices = numpy_topk(
+            self.input_data, axis=self.axis, k=self.k, largest=self.largest)
+        self.outputs = {'Out': output, 'Indices': indices}
+
+
+class TestTopkOp6(OpTest):
+    def init_args(self):
+        self.k = 100
+        self.axis = 1
+        self.largest = True
+
+    def setUp(self):
+        self.op_type = "top_k_v2"
+        self.place = paddle.device.MLUPlace(0)
+        self.__class__.use_mlu = True
+        self.dtype = np.float32
+        self.input_data = np.random.rand(80, 16384).astype(self.dtype)
+        self.init_args()
+        self.inputs = {'X': self.input_data}
+        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
+        output, indices = numpy_topk(
+            self.input_data, axis=self.axis, k=self.k, largest=self.largest)
+        self.outputs = {'Out': output, 'Indices': indices}
+
+
+class TestTopKAPI(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(123)
+        self.dtype = np.float32
+        self.input_data = np.random.rand(6, 7, 8).astype(self.dtype)
+        self.large_input_data = np.random.rand(2, 1030).astype(self.dtype)
+
+    def run_dygraph(self, place):
+        paddle.disable_static(place)
+        input_tensor = paddle.to_tensor(self.input_data)
+        large_input_tensor = paddle.to_tensor(self.large_input_data)
+        # basic test case 1
+        paddle_result = paddle.topk(input_tensor, k=2)
+        numpy_result = numpy_topk(self.input_data, k=2)
+        self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0]))
+        self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1]))
+        # basic test case 2 with axis
+        paddle_result = paddle.topk(input_tensor, k=2, axis=1)
+        numpy_result = numpy_topk(self.input_data, k=2, axis=1)
+        self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0]))
+        self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1]))
+        # basic test case 3 with tensor K
+        k_tensor = paddle.to_tensor(np.array([2]))
+        paddle_result = paddle.topk(input_tensor, k=k_tensor, axis=1)
+        numpy_result = numpy_topk(self.input_data, k=2, axis=1)
+        self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0]))
+        self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1]))
+        # basic test case 4 with largest=False
+        k_tensor = paddle.to_tensor(np.array([2]))
+        paddle_result = paddle.topk(input_tensor, k=2, axis=1, largest=False)
+        numpy_result = numpy_topk(self.input_data, k=2, axis=1, largest=False)
+        self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0]))
+        self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1]))
+        # basic test case 5 with axis=-1
+        k_tensor = paddle.to_tensor(np.array([2]))
+        paddle_result = paddle.topk(input_tensor, k=2, axis=-1, largest=False)
+        numpy_result = numpy_topk(self.input_data, k=2, axis=-1, largest=False)
+        self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0]))
+        self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1]))
+        # basic test case 6 for the partial sort
+        paddle_result = paddle.topk(large_input_tensor, k=1, axis=-1)
+        numpy_result = numpy_topk(self.large_input_data, k=1, axis=-1)
+        self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0]))
+        self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1]))
+        # basic test case 7 for unsorted output
+        paddle_result = paddle.topk(input_tensor, k=2, axis=1, sorted=False)
+        sort_paddle = numpy_topk(
+            np.array(paddle_result[0].numpy()), axis=1, k=2)
+        numpy_result = numpy_topk(self.input_data, k=2, axis=1)
+        self.assertTrue(np.allclose(sort_paddle[0], numpy_result[0]))
+
+    def run_static(self, place):
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program(),
+                                         paddle.static.Program()):
+            input_tensor = paddle.static.data(
+                name="x", shape=[6, 7, 8], dtype="float32")
+            large_input_tensor = paddle.static.data(
+                name="large_x", shape=[2, 1030], dtype="float32")
+            k_tensor = paddle.static.data(name="k", shape=[1], dtype="int32")
+            result1 = paddle.topk(input_tensor, k=2)
+            result2 = paddle.topk(input_tensor, k=2, axis=-1)
+            result3 = paddle.topk(input_tensor, k=k_tensor, axis=1)
+            self.assertEqual(result3[0].shape, (6, -1, 8))
+            self.assertEqual(result3[1].shape, (6, -1, 8))
+            result4 = paddle.topk(input_tensor, k=2, axis=1, largest=False)
+            result5 = paddle.topk(input_tensor, k=2, axis=-1, largest=False)
+            result6 = paddle.topk(large_input_tensor, k=1, axis=-1)
+            result7 = paddle.topk(input_tensor, k=2, axis=1, sorted=False)
+            exe = paddle.static.Executor(place)
+            input_data = np.random.rand(10, 20).astype("float32")
+            large_input_data = np.random.rand(2, 100).astype("float32")
+            paddle_result = exe.run(
+                feed={
+                    "x": self.input_data,
+                    "large_x": self.large_input_data,
+                    "k": np.array([2]).astype("int32")
+                },
+                fetch_list=[
+                    result1[0], result1[1], result2[0], result2[1], result3[0],
+                    result3[1], result4[0], result4[1], result5[0], result5[1],
+                    result6[0], result6[1], result7[0], result7[1]
+                ])
+            numpy_result = numpy_topk(self.input_data, k=2)
+            self.assertTrue(np.allclose(paddle_result[0], numpy_result[0]))
+            self.assertTrue(np.allclose(paddle_result[1], numpy_result[1]))
+            numpy_result = numpy_topk(self.input_data, k=2, axis=-1)
+            self.assertTrue(np.allclose(paddle_result[2], numpy_result[0]))
+            self.assertTrue(np.allclose(paddle_result[3], numpy_result[1]))
+            numpy_result = numpy_topk(self.input_data, k=2, axis=1)
+            self.assertTrue(np.allclose(paddle_result[4], numpy_result[0]))
+            self.assertTrue(np.allclose(paddle_result[5], numpy_result[1]))
+            numpy_result = numpy_topk(
+                self.input_data, k=2, axis=1, largest=False)
+            self.assertTrue(np.allclose(paddle_result[6], numpy_result[0]))
+            self.assertTrue(np.allclose(paddle_result[7], numpy_result[1]))
+            numpy_result = numpy_topk(
+                self.input_data, k=2, axis=-1, largest=False)
+            self.assertTrue(np.allclose(paddle_result[8], numpy_result[0]))
+            self.assertTrue(np.allclose(paddle_result[9], numpy_result[1]))
+            numpy_result = numpy_topk(self.large_input_data, k=1, axis=-1)
+            self.assertTrue(np.allclose(paddle_result[10], numpy_result[0]))
+            self.assertTrue(np.allclose(paddle_result[11], numpy_result[1]))
+            sort_paddle = numpy_topk(paddle_result[12], axis=1, k=2)
+            numpy_result = numpy_topk(self.input_data, k=2, axis=1)
+            self.assertTrue(np.allclose(sort_paddle[0], numpy_result[0]))
+
+    def test_cases(self):
+        places = [core.CPUPlace()]
+        if core.is_compiled_with_mlu():
+            places.append(core.MLUPlace(0))
+        for place in places:
+            self.run_dygraph(place)
+            self.run_static(place)
+
+    def test_errors(self):
+        paddle.disable_static()
+        x = paddle.to_tensor([1, 2, 3], dtype="float32")
+        with self.assertRaises(BaseException):
+            paddle.topk(x, k=-1)
+
+        with self.assertRaises(BaseException):
+            paddle.topk(x, k=0)
+
+
+if __name__ == "__main__":
+    unittest.main()
--
GitLab
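
For reference (not part of the patch): a minimal dygraph sketch of how the new
kernels are reached through the public paddle.topk API, mirroring run_dygraph
above. It assumes a Paddle build with MLU support; the place helper is the one
the tests already use.

    import numpy as np
    import paddle

    # run eagerly on the first MLU card, as the tests above do
    paddle.disable_static(paddle.device.MLUPlace(0))

    x = paddle.to_tensor(np.random.rand(4, 8).astype("float32"))
    # paddle.topk lowers to the top_k_v2 op, i.e. TopkV2MLUKernel here
    values, indices = paddle.topk(x, k=2, axis=-1)

    # the kernel computes int32 indices via CNNL, then casts them to int64
    assert indices.dtype == paddle.int64
    ref_values = -np.sort(-x.numpy(), axis=-1)[:, :2]
    assert np.allclose(values.numpy(), ref_values)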