diff --git a/paddle/fluid/operators/top_k_v2_op_xpu.cc b/paddle/fluid/operators/top_k_v2_op_xpu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..141a0ede4f8b0bec5058dbf5cb5bf8cb9a9fc767
--- /dev/null
+++ b/paddle/fluid/operators/top_k_v2_op_xpu.cc
@@ -0,0 +1,198 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef PADDLE_WITH_XPU
+
+#include <memory>
+
+#include "paddle/fluid/operators/top_k_op.h"
+#include "paddle/fluid/operators/transpose_op.h"
+#include "xpu/refactor/math.h"
+
+namespace paddle {
+namespace operators {
+using Tensor = framework::Tensor;
+template <typename T>
+class TopkV2XPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<Tensor>("X");
+    auto* output = ctx.Output<Tensor>("Out");
+    auto* indices = ctx.Output<Tensor>("Indices");
+    const auto& in_dims = input->dims();
+    const T* in_data = input->data<T>();
+    int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
+    T* output_data = output->mutable_data<T>(ctx.GetPlace());
+    const auto& out_dims = output->dims();
+
+    const auto& sorted = static_cast<bool>(ctx.Attr<bool>("sorted"));
+    const auto& largest = static_cast<bool>(ctx.Attr<bool>("largest"));
+    PADDLE_ENFORCE_EQ(
+        sorted, true,
+        platform::errors::External(
+            "XPU API does not support unsorted topk operation currently."
+            " Operator will be supported in future update."));
+    PADDLE_ENFORCE_EQ(
+        largest, true,
+        platform::errors::External(
+            "XPU API does not support smallest topk operation currently."
+            " Operator will be supported in future update."));
+ " Operator will be supported in future update.")); + + int axis = static_cast(ctx.Attr("axis")); + if (axis < 0) axis += in_dims.size(); + + size_t k = static_cast(ctx.Attr("k")); + auto* k_t = ctx.Input("K"); + if (k_t) { + k = k_t->data()[0]; + framework::DDim output_dims = output->dims(); + output_dims[axis] = k; + output->Resize(output_dims); + indices->Resize(output_dims); + } + if (axis + 1 == in_dims.size()) { + auto& dev_ctx = ctx.template device_context(); + xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); + int32_t* indices_int_data = + RAII_GUARD.alloc_l3_or_gm(indices->numel()); + + const size_t row = framework::product( + framework::slice_ddim(in_dims, 0, in_dims.size() - 1)); + const size_t col = in_dims[in_dims.size() - 1]; + int r = xpu::sorted_topk(dev_ctx.x_context(), in_data, output_data, + indices_int_data, row, col, k); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External( + "XPU API return wrong value[%d %s] in call kernel name " + "[%s], please check " + "where Baidu Kunlun Card is properly installed.", + r, XPUAPIErrorMsg[r], "sorted_topk")); + r = xpu::cast_v2(dev_ctx.x_context(), + (const int32_t*)indices_int_data, + indices_data, indices->numel()); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External( + "XPU API return wrong value[%d %s] in call kernel name " + "[%s], please check " + "where Baidu Kunlun Card is properly installed.", + r, XPUAPIErrorMsg[r], "cast_v2")); + + } else { + // do transpose if axis is not the last dim of input + std::vector trans_axes; + for (int i = 0; i < axis; i++) { + trans_axes.emplace_back(i); + } + for (int i = axis + 1; i < in_dims.size(); i++) { + trans_axes.emplace_back(i); + } + trans_axes.emplace_back(axis); + // Get input and output dims for transpose + framework::DDim trans_dims(in_dims); + framework::DDim trans_out_dims(output->dims()); + for (size_t i = 0; i < trans_axes.size(); i++) { + trans_dims[i] = in_dims[trans_axes[i]]; + trans_out_dims[i] = out_dims[trans_axes[i]]; + } + + std::vector x_shape_host(in_dims.size(), 0); + for (int i = 0; i < in_dims.size(); ++i) { + x_shape_host[i] = in_dims[i]; + } + + auto& dev_ctx = ctx.template device_context(); + xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); + T* trans_in_data = RAII_GUARD.alloc_l3_or_gm(input->numel()); + + // Transpose and save interval output to trans_in + int r = xpu::transpose(dev_ctx.x_context(), in_data, trans_in_data, + x_shape_host, trans_axes); + PADDLE_ENFORCE_EQ( + r, xpu::Error_t::SUCCESS, + platform::errors::External("XPU API 1st Transpose kernel" + " returns wrong value[%d %s]!", + r, XPUAPIErrorMsg[r])); + + T* trans_out_data = RAII_GUARD.alloc_l3_or_gm(output->numel()); + int64_t* trans_idx_data = + RAII_GUARD.alloc_l3_or_gm(output->numel()); + int32_t* trans_idx_int32_data = + RAII_GUARD.alloc_l3_or_gm(output->numel()); + const size_t row = framework::product( + framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1)); + const size_t col = trans_dims[trans_dims.size() - 1]; + + // Do top k on transposed input + r = xpu::sorted_topk(dev_ctx.x_context(), trans_in_data, + trans_out_data, trans_idx_int32_data, row, col, + k); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External( + "XPU API return wrong value[%d %s] in call kernel name " + "[%s], please check " + "where Baidu Kunlun Card is properly installed.", + r, XPUAPIErrorMsg[r], "sorted_topk")); + + r = xpu::cast_v2(dev_ctx.x_context(), + (const int32_t*)trans_idx_int32_data, + trans_idx_data, indices->numel()); + PADDLE_ENFORCE_EQ( + r, 
diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
index 0742c2a3968b9d412ce65d60d7149616649b65b1..3d7739f5a06aff15a17764441c46bea7d63a33fb 100644
--- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h
+++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
@@ -327,6 +327,7 @@ XPUOpMap& get_kl2_ops() {
                     pOpKernelType(vartype::FP16, XPUPlace())})},
     {"transpose", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                                 pOpKernelType(vartype::FP16, XPUPlace())})},
+    {"top_k_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"unsqueeze2_grad",
      XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
                    pOpKernelType(vartype::INT64, XPUPlace()),
@@ -348,6 +349,7 @@ XPUOpMap& get_kl2_ops() {
     {"where", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
                             pOpKernelType(vartype::INT64, XPUPlace()),
                             pOpKernelType(vartype::FP32, XPUPlace())})},
+    // AddMore
 };
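With the FP32 kernel now listed in `get_kl2_ops()`, `paddle.topk` should dispatch to it on a KL2 device. A hedged smoke test (assumes a Kunlun card is visible as device 0; only `sorted=True`/`largest=True` are accepted by this kernel, per the `PADDLE_ENFORCE_EQ` checks above):

```python
import paddle

if paddle.is_compiled_with_xpu():
    paddle.set_device("xpu:0")
    x = paddle.rand([10, 20], dtype="float32")
    # Defaults already match the kernel's constraints (sorted, largest).
    values, indices = paddle.topk(x, k=3, axis=1, largest=True, sorted=True)
    print(values.shape, indices.shape)  # [10, 3] [10, 3]
```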
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_top_k_v2_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_top_k_v2_op_xpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0f4b4244355ba68c1f28683831332019a556064
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_top_k_v2_op_xpu.py
@@ -0,0 +1,289 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import sys
+sys.path.append("..")
+from op_test import OpTest
+import paddle
+import paddle.fluid.core as core
+
+paddle.enable_static()
+
+
+def numpy_topk(x, k=1, axis=-1, largest=True):
+    if axis < 0:
+        axis = len(x.shape) + axis
+    if largest:
+        indices = np.argsort(-x, axis=axis)
+    else:
+        indices = np.argsort(x, axis=axis)
+    if largest:
+        value = -np.sort(-x, axis=axis)
+    else:
+        value = np.sort(x, axis=axis)
+    indices = indices.take(indices=range(0, k), axis=axis)
+    value = value.take(indices=range(0, k), axis=axis)
+    return value, indices
+
+
+class TestTopkOp(OpTest):
+    def init_args(self):
+        self.k = 3
+        self.axis = 1
+        self.largest = True
+
+    def setUp(self):
+        self.op_type = "top_k_v2"
+        self.dtype = np.float32
+        self.input_data = np.random.rand(10, 20)
+        self.init_args()
+        self.inputs = {'X': self.input_data}
+        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
+        output, indices = numpy_topk(
+            self.input_data, axis=self.axis, k=self.k, largest=self.largest)
+        self.outputs = {'Out': output, 'Indices': indices}
+
+    def test_check_output(self):
+        if paddle.is_compiled_with_xpu():
+            place = paddle.XPUPlace(0)
+            self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        if paddle.is_compiled_with_xpu():
+            place = paddle.XPUPlace(0)
+            self.check_grad_with_place(place, set(['X']), 'Out')
+
+
+class TestTopkOp1(TestTopkOp):
+    def init_args(self):
+        self.k = 3
+        self.axis = 1
+        self.largest = True
+
+    def setUp(self):
+        self.op_type = "top_k_v2"
+        self.dtype = np.float32
+        self.input_data = np.random.rand(10, 10, 5)
+        self.init_args()
+        self.inputs = {'X': self.input_data}
+        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
+        output, indices = numpy_topk(
+            self.input_data, axis=self.axis, k=self.k, largest=self.largest)
+        self.outputs = {'Out': output, 'Indices': indices}
+
+
+class TestTopkOp2(TestTopkOp):
+    def init_args(self):
+        self.k = 3
+        self.axis = 1
+        self.largest = True
+
+    def setUp(self):
+        self.op_type = "top_k_v2"
+        self.dtype = np.float32
+        self.input_data = np.random.rand(10, 10, 5)
+        self.init_args()
+        self.inputs = {'X': self.input_data}
+        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
+        output, indices = numpy_topk(
+            self.input_data, axis=self.axis, k=self.k, largest=self.largest)
+        self.outputs = {'Out': output, 'Indices': indices}
+
+
+class TestTopkOp3(TestTopkOp):
+    def init_args(self):
+        self.k = 5
+        self.axis = 1
+        self.largest = True
+
+    def setUp(self):
+        self.op_type = "top_k_v2"
+        self.dtype = np.float32
+        self.input_data = np.random.rand(10, 10, 5)
+        self.init_args()
+        self.inputs = {'X': self.input_data}
+        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
+        output, indices = numpy_topk(
+            self.input_data, axis=self.axis, k=self.k, largest=self.largest)
+        self.outputs = {'Out': output, 'Indices': indices}
+
+
+class TestTopkOp4(TestTopkOp):
+    def init_args(self):
+        self.k = 1
+        self.axis = 1
+        self.largest = True
+
+    def setUp(self):
+        self.op_type = "top_k_v2"
+        self.dtype = np.float32
+        self.input_data = np.random.rand(10, 10, 5)
+        self.init_args()
+        self.inputs = {'X': self.input_data}
+        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
+        output, indices = numpy_topk(
+            self.input_data, axis=self.axis, k=self.k, largest=self.largest)
+        self.outputs = {'Out': output, 'Indices': indices}
+
+
+class TestTopkOp5(TestTopkOp):
+    def init_args(self):
+        self.k = 3
+        self.axis = 2
+        self.largest = True
+
+    def setUp(self):
+        self.op_type = "top_k_v2"
+        self.dtype = np.float32
+        self.input_data = np.random.rand(10, 10, 5)
+        self.init_args()
+        self.inputs = {'X': self.input_data}
+        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
+        output, indices = numpy_topk(
+            self.input_data, axis=self.axis, k=self.k, largest=self.largest)
+        self.outputs = {'Out': output, 'Indices': indices}
+
+
+class TestTopkOp6(TestTopkOp):
+    def init_args(self):
+        self.k = 5
+        self.axis = 1
+        self.largest = True
+
+    def setUp(self):
+        self.op_type = "top_k_v2"
+        self.dtype = np.float32
+        self.input_data = np.random.rand(8, 32, 64)
+        self.init_args()
+        self.inputs = {'X': self.input_data}
+        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
+        output, indices = numpy_topk(
+            self.input_data, axis=self.axis, k=self.k, largest=self.largest)
+        self.outputs = {'Out': output, 'Indices': indices}
+
+
+class TestTopkOp7(TestTopkOp):
+    def init_args(self):
+        self.k = 10
+        self.axis = 2
+        self.largest = True
+
+    def setUp(self):
+        self.op_type = "top_k_v2"
+        self.dtype = np.float32
+        self.input_data = np.random.rand(8, 5, 10, 16)
+        self.init_args()
+        self.inputs = {'X': self.input_data}
+        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
+        output, indices = numpy_topk(
+            self.input_data, axis=self.axis, k=self.k, largest=self.largest)
+        self.outputs = {'Out': output, 'Indices': indices}
+
+
+class TestTopkOp8(TestTopkOp):
+    def init_args(self):
+        self.k = 1
+        self.axis = 1
+        self.largest = True
+
+    def setUp(self):
+        self.op_type = "top_k_v2"
+        self.dtype = np.float32
+        self.input_data = np.random.rand(8, 32, 64)
+        self.init_args()
+        self.inputs = {'X': self.input_data}
+        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
+        output, indices = numpy_topk(
+            self.input_data, axis=self.axis, k=self.k, largest=self.largest)
+        self.outputs = {'Out': output, 'Indices': indices}
+
+
+class TestTopkOp9(TestTopkOp):
+    def init_args(self):
+        self.k = 3
+        self.axis = 1
+        self.largest = True
+
+    def setUp(self):
+        self.op_type = "top_k_v2"
+        self.dtype = np.float32
+        self.input_data = np.random.rand(10, 10, 5)
+        self.init_args()
+        self.inputs = {'X': self.input_data}
+        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
+        output, indices = numpy_topk(
+            self.input_data, axis=self.axis, k=self.k, largest=self.largest)
+        self.outputs = {'Out': output, 'Indices': indices}
+
+
+class TestTopkOp10(TestTopkOp):
+    def init_args(self):
+        self.k = 3
+        self.axis = 1
+        self.largest = True
+
+    def setUp(self):
+        self.op_type = "top_k_v2"
+        self.dtype = np.float32
+        self.input_data = np.random.rand(10, 10, 5)
+        self.init_args()
+        self.inputs = {'X': self.input_data}
+        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
+        output, indices = numpy_topk(
+            self.input_data, axis=self.axis, k=self.k, largest=self.largest)
+        self.outputs = {'Out': output, 'Indices': indices}
+
+
+class TestTopkOp11(TestTopkOp):
+    def init_args(self):
+        self.k = 5
+        self.axis = 1
+        self.largest = True
+
+    def setUp(self):
+        self.op_type = "top_k_v2"
+        self.dtype = np.float32
+        self.input_data = np.random.rand(10, 10, 5)
+        self.init_args()
+        self.inputs = {'X': self.input_data}
+        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
+        output, indices = numpy_topk(
+            self.input_data, axis=self.axis, k=self.k, largest=self.largest)
+        self.outputs = {'Out': output, 'Indices': indices}
+
+
+class TestTopkOp12(TestTopkOp):
+    def init_args(self):
+        self.k = 1
+        self.axis = 1
+        self.largest = True
+
+    def setUp(self):
+        self.op_type = "top_k_v2"
+        self.dtype = np.float32
+        self.input_data = np.random.rand(10, 10, 5)
+        self.init_args()
+        self.inputs = {'X': self.input_data}
+        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
+        output, indices = numpy_topk(
+            self.input_data, axis=self.axis, k=self.k, largest=self.largest)
+        self.outputs = {'Out': output, 'Indices': indices}
+
+
+if __name__ == "__main__":
+    unittest.main()
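For quickly eyeballing test expectations, a worked example of the `numpy_topk` reference helper, reusing the function exactly as defined in the test file above (pure NumPy, no XPU required):

```python
import numpy as np

x = np.array([[1., 3., 2.],
              [9., 4., 6.]], dtype=np.float32)
# numpy_topk is the reference helper from test_top_k_v2_op_xpu.py above.
value, indices = numpy_topk(x, k=2, axis=1, largest=True)
# value   -> [[3., 2.], [9., 6.]]   (sorted, largest first)
# indices -> [[1, 2], [0, 2]]
```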