diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake
index 02abf08a99ce8fcb1b3ca7d8e38c0b3103a6bb46..041c77943fa04cfabef8a207273987e535fb4d1f 100644
--- a/cmake/external/xpu.cmake
+++ b/cmake/external/xpu.cmake
@@ -35,7 +35,7 @@ ELSE ()
 ENDIF()

 SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
-SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210909")
+SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210917")
 SET(XPU_XRE_URL "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
 SET(XPU_XDNN_URL "${XPU_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
 SET(XPU_XCCL_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210623/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
diff --git a/paddle/fluid/operators/arg_max_op_xpu.cc b/paddle/fluid/operators/arg_max_op_xpu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..8060b5cf755c0ef4f0bb0c87405c8da809db33c8
--- /dev/null
+++ b/paddle/fluid/operators/arg_max_op_xpu.cc
@@ -0,0 +1,70 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef PADDLE_WITH_XPU
+
+#include "paddle/fluid/operators/arg_min_max_op_base.h"
+
+namespace paddle {
+namespace operators {
+using Tensor = framework::Tensor;
+
+template <typename DeviceContext, typename T>
+class ArgMaxXPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<framework::LoDTensor>("X");
+    auto* out = ctx.Output<framework::LoDTensor>("Out");
+    auto dtype = ctx.Attr<int>("dtype");
+    PADDLE_ENFORCE_EQ(
+        (dtype < 0 || dtype == 3), true,
+        platform::errors::InvalidArgument(
+            "The attribute of dtype in xpu argmin/argmax must be [%s], but "
+            "received [%s]",
+            paddle::framework::DataTypeToString(
+                framework::proto::VarType::INT64),
+            paddle::framework::DataTypeToString(
+                static_cast<framework::proto::VarType::Type>(dtype))));
+
+    out->template mutable_data<int64_t>(ctx.GetPlace());
+    auto axis = ctx.Attr<int64_t>("axis");
+    const bool& flatten = ctx.Attr<bool>("flatten");
+    framework::DDim x_dims;
+    if (flatten) {
+      x_dims = framework::make_ddim({x->numel()});
+      // if flatten, the axis just as 0
+      axis = 0;
+    } else {
+      x_dims = x->dims();
+      if (axis < 0) axis += x_dims.size();
+    }
+    auto xdims_vec = framework::vectorize<int>(x_dims);
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    int r = xpu::argmax(dev_ctx.x_context(), x->data<T>(),
+                        out->data<int64_t>(), xdims_vec, axis);
+    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
+                      platform::errors::External(
+                          "XPU argmax kernel return wrong value[%d %s].", r,
+                          XPUAPIErrorMsg[r]));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_XPU_KERNEL(
+    arg_max, ops::ArgMaxXPUKernel<paddle::platform::XPUDeviceContext, float>);
+
+#endif
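Note on usage: the kernel above is registered only for float32 inputs with int64 indices, which is why the dtype attribute must be negative (unset) or 3 (proto::VarType::INT64). A minimal dygraph sketch of how this path is expected to be hit from Python follows; it assumes an XPU build with device 0 available and uses only APIs that already appear in the unit tests later in this patch (paddle.argmax, paddle.XPUPlace, paddle.to_tensor).

import numpy as np
import paddle

# Sketch only: assumes a PADDLE_WITH_XPU build and a visible XPU device 0.
if paddle.is_compiled_with_xpu():
    paddle.disable_static(paddle.XPUPlace(0))
    x = paddle.to_tensor(np.random.random((3, 4, 5)).astype('float32'))
    idx = paddle.argmax(x, axis=-1)  # should dispatch to the XPU kernel; output dtype is int64
    flat_idx = paddle.argmax(x)      # axis=None exercises the flatten branch (axis forced to 0)
    np.testing.assert_array_equal(idx.numpy(), np.argmax(x.numpy(), axis=-1))
    paddle.enable_static()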
diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt
index efbd653ffd3b0d43f485e02a5c603ae51f1c1a9a..c04d04f841388253e9130f9ab07f8601910d74b4 100644
--- a/paddle/fluid/operators/detection/CMakeLists.txt
+++ b/paddle/fluid/operators/detection/CMakeLists.txt
@@ -17,8 +17,6 @@ endfunction()
 detection_library(bipartite_match_op SRCS bipartite_match_op.cc)
 detection_library(box_coder_op SRCS box_coder_op.cc box_coder_op.cu)
-detection_library(iou_similarity_op SRCS iou_similarity_op.cc
-iou_similarity_op.cu)
 detection_library(mine_hard_examples_op SRCS mine_hard_examples_op.cc)
 detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu)
 detection_library(density_prior_box_op SRCS density_prior_box_op.cc density_prior_box_op.cu)
@@ -58,6 +56,12 @@ else()
     detection_library(collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc)
 endif()

+if(WITH_XPU)
+  detection_library(iou_similarity_op SRCS iou_similarity_op.cc iou_similarity_op_xpu.cc)
+else()
+  detection_library(iou_similarity_op SRCS iou_similarity_op.cc iou_similarity_op.cu)
+endif()
+
 detection_library(roi_perspective_transform_op SRCS roi_perspective_transform_op.cc roi_perspective_transform_op.cu)
 #Export local libraries to parent
 # set(DETECTION_LIBRARY ${LOCAL_DETECTION_LIBS} PARENT_SCOPE)
diff --git a/paddle/fluid/operators/detection/iou_similarity_op_xpu.cc b/paddle/fluid/operators/detection/iou_similarity_op_xpu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..59238b92c508509d264d471c45522999e7a101c2
--- /dev/null
+++ b/paddle/fluid/operators/detection/iou_similarity_op_xpu.cc
@@ -0,0 +1,55 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#ifdef PADDLE_WITH_XPU
+
+#include "paddle/fluid/operators/detection/iou_similarity_op.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+class XPUIOUSimilarityKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const framework::LoDTensor* in_x = ctx.Input<framework::LoDTensor>("X");
+    const framework::Tensor* in_y = ctx.Input<framework::Tensor>("Y");
+    bool normalized = ctx.Attr<bool>("box_normalized");
+    framework::LoDTensor* out = ctx.Output<framework::LoDTensor>("Out");
+
+    int x_n = in_x->dims()[0];
+    int y_n = in_y->dims()[0];
+    T eps = static_cast<T>(1e-10);
+
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    int r = xpu::iou_similarity(
+        dev_ctx.x_context(), in_x->data<T>(), in_y->data<T>(),
+        out->mutable_data<T>(ctx.GetPlace()), x_n, y_n, eps, normalized);
+    PADDLE_ENFORCE_EQ(
+        r, XPU_SUCCESS,
+        platform::errors::External(
+            "XPU iou_similarity kernel return wrong value[%d %s].", r,
+            XPUAPIErrorMsg[r]));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+using XPU = paddle::platform::XPUDeviceContext;
+
+REGISTER_OP_XPU_KERNEL(iou_similarity, ops::XPUIOUSimilarityKernel<XPU, float>);
+
+#endif
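For reference, what the new XPU kernel is expected to compute (and what the unit test at the end of this patch checks) can be written directly in NumPy. The helper below is hypothetical and not part of the patch: boxes are (xmin, ymin, xmax, ymax) rows, the +1 offset reproduces the un-normalized pixel-box convention, and eps is only used here to guard the division, since how xdnn applies the eps argument internally is not visible from this file.

import numpy as np

def iou_similarity_reference(x, y, box_normalized=True, eps=1e-10):
    # Hypothetical NumPy reference: x is [N, 4], y is [M, 4]; returns an [N, M] IoU matrix.
    offset = 0.0 if box_normalized else 1.0
    out = np.zeros((x.shape[0], y.shape[0]), dtype=x.dtype)
    for i, (xmin1, ymin1, xmax1, ymax1) in enumerate(x):
        for j, (xmin2, ymin2, xmax2, ymax2) in enumerate(y):
            area1 = (ymax1 - ymin1 + offset) * (xmax1 - xmin1 + offset)
            area2 = (ymax2 - ymin2 + offset) * (xmax2 - xmin2 + offset)
            inter_w = max(min(xmax1, xmax2) - max(xmin1, xmin2) + offset, 0.0)
            inter_h = max(min(ymax1, ymax2) - max(ymin1, ymin2) + offset, 0.0)
            inter = inter_w * inter_h
            out[i, j] = inter / (area1 + area2 - inter + eps)
    return out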
diff --git a/paddle/fluid/platform/xpu/xpu1_op_list.h b/paddle/fluid/platform/xpu/xpu1_op_list.h
index cdd60a856fbc90865ee29a1e3a1c371352b87618..c9545d675f90edbb39f9ed594fcbcd1bf2d1f5b4 100644
--- a/paddle/fluid/platform/xpu/xpu1_op_list.h
+++ b/paddle/fluid/platform/xpu/xpu1_op_list.h
@@ -318,7 +318,10 @@ XPUOpMap& get_kl1_ops() {
                     pOpKernelType(vartype::INT8, XPUPlace()),
                     pOpKernelType(vartype::UINT8, XPUPlace()),
                     pOpKernelType(vartype::FP32, XPUPlace())})},
-    {"momuntem", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}
+    {"momuntem", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"iou_similarity",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"arg_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}
     // AddMore
 };
diff --git a/paddle/fluid/platform/xpu/xpu2_op_list.h b/paddle/fluid/platform/xpu/xpu2_op_list.h
index 5b9e1a34bfcd5518992cff09355253ffdc96ca84..651243a4dfe6673ec415519eb6997cae9b1c27af 100644
--- a/paddle/fluid/platform/xpu/xpu2_op_list.h
+++ b/paddle/fluid/platform/xpu/xpu2_op_list.h
@@ -107,7 +107,9 @@ XPUOpMap& get_kl2_ops() {
     {"transpose2_grad",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                    pOpKernelType(vartype::FP16, XPUPlace())})},
-
+    {"iou_similarity",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"arg_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}
     // AddMore
 };
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_arg_max_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_arg_max_op_xpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbdd9db8ee7f2c4d297e6c27b8f6ee006c2b19f4
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_arg_max_op_xpu.py
@@ -0,0 +1,194 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import sys
+sys.path.append("..")
+from op_test import OpTest
+from op_test_xpu import XPUOpTest
+import paddle
+import paddle.fluid.core as core
+
+paddle.enable_static()
+
+
+class XPUBaseTestCase(XPUOpTest):
+    def initTestCase(self):
+        self.dims = (3, 4)
+        self.dtype = 'float32'
+        self.axis = 1
+
+    def setUp(self):
+        self.initTestCase()
+        self.__class__.op_type = 'arg_max'
+        self.__class__.use_xpu = True
+        np.random.seed(2021)
+        self.x = (np.random.random(self.dims)).astype(self.dtype)
+        self.inputs = {'X': self.x}
+        self.attrs = {'axis': self.axis, 'use_xpu': True}
+        if self.op_type == "arg_min":
+            self.outputs = {'Out': np.argmin(self.x, axis=self.axis)}
+        else:
+            self.outputs = {'Out': np.argmax(self.x, axis=self.axis)}
+
+    def test_check_output(self):
+        if paddle.is_compiled_with_xpu():
+            place = paddle.XPUPlace(0)
+            self.check_output_with_place(place)
+
+
+# test argmax, dtype: float32
+class TestArgMaxFloat32Case1(XPUBaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (3, 4, 5)
+        self.dtype = 'float32'
+        self.axis = -1
+
+
+class TestArgMaxFloat32Case2(XPUBaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (3, 4, 5)
+        self.dtype = 'float32'
+        self.axis = 0
+
+
+class TestArgMaxFloat32Case3(XPUBaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (3, 4, 5)
+        self.dtype = 'float32'
+        self.axis = 1
+
+
+class TestArgMaxFloat32Case4(XPUBaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (3, 4, 5)
+        self.dtype = 'float32'
+        self.axis = 2
+
+
+class TestArgMaxFloat32Case5(XPUBaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (3, 4)
+        self.dtype = 'float32'
+        self.axis = -1
+
+
+class TestArgMaxFloat32Case6(XPUBaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (3, 4)
+        self.dtype = 'float32'
+        self.axis = 0
+
+
+class TestArgMaxFloat32Case7(XPUBaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (3, 4)
+        self.dtype = 'float32'
+        self.axis = 1
+
+
+class TestArgMaxFloat32Case8(XPUBaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (1, )
+        self.dtype = 'float32'
+        self.axis = 0
+
+
+class TestArgMaxFloat32Case9(XPUBaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (2, )
+        self.dtype = 'float32'
+        self.axis = 0
+
+
+class TestArgMaxFloat32Case10(XPUBaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (3, )
+        self.dtype = 'float32'
+        self.axis = 0
+
+
+class TestArgMaxAPI(unittest.TestCase):
+    def initTestCase(self):
+        self.dims = (3, 4, 5)
+        self.dtype = 'float32'
+        self.axis = 0
+
+    def setUp(self):
+        self.initTestCase()
+        self.__class__.use_Xpu = True
+        self.place = [paddle.XPUPlace(0)]
+
+    def test_dygraph_api(self):
+        def run(place):
+            paddle.disable_static(place)
+            np.random.seed(2021)
+            numpy_input = (np.random.random(self.dims)).astype(self.dtype)
+            tensor_input = paddle.to_tensor(numpy_input)
+            numpy_output = np.argmax(numpy_input, axis=self.axis)
+            paddle_output = paddle.argmax(tensor_input, axis=self.axis)
+            self.assertEqual(
+                np.allclose(numpy_output, paddle_output.numpy()), True)
+            paddle.enable_static()
+
+        for place in self.place:
+            run(place)
+
+
+class TestArgMaxAPI_2(unittest.TestCase):
+    def initTestCase(self):
+        self.dims = (3, 4, 5)
+        self.dtype = 'float32'
+        self.axis = 0
+        self.keep_dims = True
+
+    def setUp(self):
+        self.initTestCase()
+        self.__class__.use_xpu = True
+        self.place = [paddle.XPUPlace(0)]
+
+    def test_dygraph_api(self):
+        def run(place):
+            paddle.disable_static(place)
+            np.random.seed(2021)
+            numpy_input = (np.random.random(self.dims)).astype(self.dtype)
+            tensor_input = paddle.to_tensor(numpy_input)
+            numpy_output = np.argmax(
+                numpy_input, axis=self.axis).reshape(1, 4, 5)
+            paddle_output = paddle.argmax(
+                tensor_input, axis=self.axis, keepdim=self.keep_dims)
+            self.assertEqual(
+                np.allclose(numpy_output, paddle_output.numpy()), True)
+            self.assertEqual(numpy_output.shape, paddle_output.numpy().shape)
+            paddle.enable_static()
+
+        for place in self.place:
+            run(place)
+
+
+if __name__ == '__main__':
+    unittest.main()
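The OpTest-based cases in the next file feed iou_similarity through check_output_with_place. For context, an end-to-end static-graph sketch of the same operator on an XPU place is given here; it assumes fluid.layers.iou_similarity is the Python front end for this op (that layer is not part of this patch) and that an XPU device 0 is available.

import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='boxes_x', shape=[2, 4], dtype='float32')
    y = paddle.static.data(name='boxes_y', shape=[3, 4], dtype='float32')
    # Assumed front end; box_normalized=False matches the default test case below.
    sim = fluid.layers.iou_similarity(x, y, box_normalized=False)

if paddle.is_compiled_with_xpu():
    exe = paddle.static.Executor(paddle.XPUPlace(0))
    exe.run(startup_prog)
    out, = exe.run(main_prog,
                   feed={'boxes_x': np.random.rand(2, 4).astype('float32'),
                         'boxes_y': np.random.rand(3, 4).astype('float32')},
                   fetch_list=[sim])
    print(out.shape)  # expected (2, 3)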
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_iou_similarity_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_iou_similarity_op_xpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..b745dce9efef4cc20a7651b4866c0cc32df72427
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_iou_similarity_op_xpu.py
@@ -0,0 +1,116 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import unittest
+import sys
+sys.path.append("..")
+
+import unittest
+import numpy as np
+import numpy.random as random
+import sys
+import math
+from op_test import OpTest
+from op_test_xpu import XPUOpTest
+import paddle
+
+paddle.enable_static()
+
+
+class TestXPUIOUSimilarityOp(XPUOpTest):
+    def test_check_output(self):
+        if paddle.is_compiled_with_xpu():
+            place = paddle.XPUPlace(0)
+            self.check_output_with_place(place)
+
+    def setUp(self):
+        self.op_type = "iou_similarity"
+        self.boxes1 = random.rand(2, 4).astype('float32')
+        self.boxes2 = random.rand(3, 4).astype('float32')
+        self.output = random.rand(2, 3).astype('float32')
+        self.box_normalized = False
+        # run python iou computation
+        self._compute_iou()
+        self.inputs = {'X': self.boxes1, 'Y': self.boxes2}
+        self.attrs = {"box_normalized": self.box_normalized, 'use_xpu': True}
+        self.outputs = {'Out': self.output}
+
+    def _compute_iou(self, ):
+        for row in range(self.boxes1.shape[0]):
+            for col in range(self.boxes2.shape[0]):
+                xmin1, ymin1, xmax1, ymax1 = self.boxes1[row]
+                xmin2, ymin2, xmax2, ymax2 = self.boxes2[col]
+                if not self.box_normalized:
+                    area1 = (ymax1 - ymin1 + 1) * (xmax1 - xmin1 + 1)
+                    area2 = (ymax2 - ymin2 + 1) * (xmax2 - xmin2 + 1)
+                else:
+                    area1 = (ymax1 - ymin1) * (xmax1 - xmin1)
+                    area2 = (ymax2 - ymin2) * (xmax2 - xmin2)
+
+                inter_xmax = min(xmax1, xmax2)
+                inter_ymax = min(ymax1, ymax2)
+                inter_xmin = max(xmin1, xmin2)
+                inter_ymin = max(ymin1, ymin2)
+                inter_height = inter_ymax - inter_ymin
+                inter_width = inter_xmax - inter_xmin
+                if not self.box_normalized:
+                    inter_height += 1
+                    inter_width += 1
+                inter_height = max(inter_height, 0)
+                inter_width = max(inter_width, 0)
+                inter_area = inter_width * inter_height
+                union_area = area1 + area2 - inter_area
+                sim_score = inter_area / union_area
+                self.output[row, col] = sim_score
+
+
+class TestXPUIOUSimilarityOpWithLoD(TestXPUIOUSimilarityOp):
+    def test_check_output(self):
+        if paddle.is_compiled_with_xpu():
+            place = paddle.XPUPlace(0)
+            self.check_output_with_place(place, check_dygraph=False)
+
+    def setUp(self):
+        super(TestXPUIOUSimilarityOpWithLoD, self).setUp()
+        self.boxes1_lod = [[1, 1]]
+        self.output_lod = [[1, 1]]
+        self.box_normalized = False
+        # run python iou computation
+        self._compute_iou()
+        self.inputs = {'X': (self.boxes1, self.boxes1_lod), 'Y': self.boxes2}
+        self.attrs = {"box_normalized": self.box_normalized}
+        self.outputs = {'Out': (self.output, self.output_lod)}
+
+
+class TestXPUIOUSimilarityOpWithBoxNormalized(TestXPUIOUSimilarityOp):
+    def test_check_output(self):
+        if paddle.is_compiled_with_xpu():
+            place = paddle.XPUPlace(0)
+            self.check_output_with_place(place, check_dygraph=False)
+
+    def setUp(self):
+        super(TestXPUIOUSimilarityOpWithBoxNormalized, self).setUp()
+        self.boxes1_lod = [[1, 1]]
+        self.output_lod = [[1, 1]]
+        self.box_normalized = True
+        # run python iou computation
+        self._compute_iou()
+        self.inputs = {'X': (self.boxes1, self.boxes1_lod), 'Y': self.boxes2}
+        self.attrs = {"box_normalized": self.box_normalized}
+        self.outputs = {'Out': (self.output, self.output_lod)}
+
+
+if __name__ == '__main__':
+    unittest.main()