Unverified commit 7bf84e2d, authored by TTerror, committed by GitHub

add argmax and iou_similarity for kunlun (#35836)

* add argmax and iou_similarity for kunlun

* add argmax and iou_similarity for kunlun

* add argmax and iou_similarity for kunlun
Parent: 1548407d
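For context, the sketch below (editor's addition, not part of this diff) shows how the new arg_max XPU kernel can be exercised from Python; it assumes a Paddle build configured with WITH_XPU=ON and an attached Kunlun card. The iou_similarity kernel is covered by the new unit test added at the end of this change.

# Hedged usage sketch, not from this PR: run argmax on a Kunlun device.
import numpy as np
import paddle

if paddle.is_compiled_with_xpu():
    paddle.set_device('xpu:0')                 # place subsequent ops on the XPU
    x = paddle.to_tensor(np.random.rand(3, 4, 5).astype('float32'))
    idx = paddle.argmax(x, axis=1)             # dispatched to the new ArgMaxXPUKernel
    print(idx.shape, idx.dtype)                # [3, 5] paddle.int64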
......@@ -35,7 +35,7 @@ ELSE ()
ENDIF()
SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210909")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210917")
SET(XPU_XRE_URL "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
SET(XPU_XDNN_URL "${XPU_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
SET(XPU_XCCL_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210623/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/arg_min_max_op_base.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class ArgMaxXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::LoDTensor>("X");
auto* out = ctx.Output<framework::LoDTensor>("Out");
auto dtype = ctx.Attr<int>("dtype");
PADDLE_ENFORCE_EQ(
(dtype < 0 || dtype == 3), true,
platform::errors::InvalidArgument(
"The attribute of dtype in xpu argmin/argmax must be [%s], but "
"received [%s]",
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64),
paddle::framework::DataTypeToString(
static_cast<framework::proto::VarType::Type>(dtype))));
out->template mutable_data<int64_t>(ctx.GetPlace());
auto axis = ctx.Attr<int64_t>("axis");
const bool& flatten = ctx.Attr<bool>("flatten");
framework::DDim x_dims;
if (flatten) {
x_dims = framework::make_ddim({x->numel()});
// if flatten is set, the axis is treated as 0
axis = 0;
} else {
x_dims = x->dims();
if (axis < 0) axis += x_dims.size();
}
auto xdims_vec = framework::vectorize<int>(x_dims);
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int r = xpu::argmax(dev_ctx.x_context(), x->data<T>(), out->data<int64_t>(),
xdims_vec, axis);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External(
"XPU argmax kernel return wrong value[%d %s].", r,
XPUAPIErrorMsg[r]));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
arg_max, ops::ArgMaxXPUKernel<paddle::platform::XPUDeviceContext, float>);
#endif
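A side note from the editor, not from the diff: the flatten/axis handling in ArgMaxXPUKernel::Compute above matches the NumPy reference below, where flatten=True views the input as a 1-D tensor of numel() elements and forces the reduction axis to 0; the helper name argmax_reference is hypothetical.

# Hedged NumPy reference (not part of the PR) for the kernel's axis handling.
import numpy as np

def argmax_reference(x, axis, flatten=False):
    if flatten:
        # flatten: reduce over all elements, axis treated as 0
        return np.asarray(np.argmax(x.reshape(-1)), dtype='int64')
    if axis < 0:
        axis += x.ndim                 # same negative-axis wrap as the kernel
    return np.argmax(x, axis=axis).astype('int64')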
......@@ -17,8 +17,6 @@ endfunction()
detection_library(bipartite_match_op SRCS bipartite_match_op.cc)
detection_library(box_coder_op SRCS box_coder_op.cc box_coder_op.cu)
-detection_library(iou_similarity_op SRCS iou_similarity_op.cc
-                  iou_similarity_op.cu)
detection_library(mine_hard_examples_op SRCS mine_hard_examples_op.cc)
detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu)
detection_library(density_prior_box_op SRCS density_prior_box_op.cc density_prior_box_op.cu)
......@@ -58,6 +56,12 @@ else()
detection_library(collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc)
endif()
+if(WITH_XPU)
+  detection_library(iou_similarity_op SRCS iou_similarity_op.cc iou_similarity_op_xpu.cc)
+else()
+  detection_library(iou_similarity_op SRCS iou_similarity_op.cc iou_similarity_op.cu)
+endif()
detection_library(roi_perspective_transform_op SRCS roi_perspective_transform_op.cc roi_perspective_transform_op.cu)
#Export local libraries to parent
# set(DETECTION_LIBRARY ${LOCAL_DETECTION_LIBS} PARENT_SCOPE)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/detection/iou_similarity_op.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class XPUIOUSimilarityKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::LoDTensor* in_x = ctx.Input<framework::LoDTensor>("X");
const framework::Tensor* in_y = ctx.Input<framework::Tensor>("Y");
bool normalized = ctx.Attr<bool>("box_normalized");
framework::LoDTensor* out = ctx.Output<framework::LoDTensor>("Out");
int x_n = in_x->dims()[0];
int y_n = in_y->dims()[0];
T eps = static_cast<T>(1e-10);
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int r = xpu::iou_similarity(
dev_ctx.x_context(), in_x->data<T>(), in_y->data<T>(),
out->mutable_data<T>(ctx.GetPlace()), x_n, y_n, eps, normalized);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External(
"XPU iou_similarity kernel return wrong value[%d %s].", r,
XPUAPIErrorMsg[r]));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using XPU = paddle::platform::XPUDeviceContext;
REGISTER_OP_XPU_KERNEL(iou_similarity, ops::XPUIOUSimilarityKernel<XPU, float>);
#endif
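For reference (editor's summary, not text from the diff), the kernel fills Out with the pairwise IoU of the x_n boxes in X against the y_n boxes in Y, presumably using the eps = 1e-10 argument to guard the division when the union area is zero; with box_normalized = false, the reference computation in the new unit test widens each box edge by one pixel before taking areas. In formula form:

$$ \mathrm{IoU}(b_i, b_j) = \frac{\mathrm{area}(b_i \cap b_j)}{\mathrm{area}(b_i) + \mathrm{area}(b_j) - \mathrm{area}(b_i \cap b_j)} $$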
......@@ -318,7 +318,10 @@ XPUOpMap& get_kl1_ops() {
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"momuntem", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}
{"momuntem", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"iou_similarity",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"arg_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}
// AddMore
};
......
......@@ -107,7 +107,9 @@ XPUOpMap& get_kl2_ops() {
{"transpose2_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"iou_similarity",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"arg_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}
// AddMore
};
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append("..")
from op_test import OpTest
from op_test_xpu import XPUOpTest
import paddle
import paddle.fluid.core as core
paddle.enable_static()
class XPUBaseTestCase(XPUOpTest):
def initTestCase(self):
self.dims = (3, 4)
self.dtype = 'float32'
self.axis = 1
def setUp(self):
self.initTestCase()
self.__class__.op_type = 'arg_max'
self.__class__.use_xpu = True
np.random.seed(2021)
self.x = (np.random.random(self.dims)).astype(self.dtype)
self.inputs = {'X': self.x}
self.attrs = {'axis': self.axis, 'use_xpu': True}
if self.op_type == "arg_min":
self.outputs = {'Out': np.argmin(self.x, axis=self.axis)}
else:
self.outputs = {'Out': np.argmax(self.x, axis=self.axis)}
def test_check_output(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
# test argmax, dtype: float32
class TestArgMaxFloat32Case1(XPUBaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, 4, 5)
self.dtype = 'float32'
self.axis = -1
class TestArgMaxFloat32Case2(XPUBaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, 4, 5)
self.dtype = 'float32'
self.axis = 0
class TestArgMaxFloat32Case3(XPUBaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, 4, 5)
self.dtype = 'float32'
self.axis = 1
class TestArgMaxFloat32Case4(XPUBaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, 4, 5)
self.dtype = 'float32'
self.axis = 2
class TestArgMaxFloat32Case5(XPUBaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, 4)
self.dtype = 'float32'
self.axis = -1
class TestArgMaxFloat32Case6(XPUBaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, 4)
self.dtype = 'float32'
self.axis = 0
class TestArgMaxFloat32Case7(XPUBaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, 4)
self.dtype = 'float32'
self.axis = 1
class TestArgMaxFloat32Case8(XPUBaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (1, )
self.dtype = 'float32'
self.axis = 0
class TestArgMaxFloat32Case9(XPUBaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (2, )
self.dtype = 'float32'
self.axis = 0
class TestArgMaxFloat32Case10(XPUBaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, )
self.dtype = 'float32'
self.axis = 0
class TestArgMaxAPI(unittest.TestCase):
def initTestCase(self):
self.dims = (3, 4, 5)
self.dtype = 'float32'
self.axis = 0
def setUp(self):
self.initTestCase()
self.__class__.use_xpu = True
self.place = [paddle.XPUPlace(0)]
def test_dygraph_api(self):
def run(place):
paddle.disable_static(place)
np.random.seed(2021)
numpy_input = (np.random.random(self.dims)).astype(self.dtype)
tensor_input = paddle.to_tensor(numpy_input)
numpy_output = np.argmax(numpy_input, axis=self.axis)
paddle_output = paddle.argmax(tensor_input, axis=self.axis)
self.assertEqual(
np.allclose(numpy_output, paddle_output.numpy()), True)
paddle.enable_static()
for place in self.place:
run(place)
class TestArgMaxAPI_2(unittest.TestCase):
def initTestCase(self):
self.dims = (3, 4, 5)
self.dtype = 'float32'
self.axis = 0
self.keep_dims = True
def setUp(self):
self.initTestCase()
self.__class__.use_xpu = True
self.place = [paddle.XPUPlace(0)]
def test_dygraph_api(self):
def run(place):
paddle.disable_static(place)
np.random.seed(2021)
numpy_input = (np.random.random(self.dims)).astype(self.dtype)
tensor_input = paddle.to_tensor(numpy_input)
numpy_output = np.argmax(
numpy_input, axis=self.axis).reshape(1, 4, 5)
paddle_output = paddle.argmax(
tensor_input, axis=self.axis, keepdim=self.keep_dims)
self.assertEqual(
np.allclose(numpy_output, paddle_output.numpy()), True)
self.assertEqual(numpy_output.shape, paddle_output.numpy().shape)
paddle.enable_static()
for place in self.place:
run(place)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import sys
sys.path.append("..")
import numpy as np
import numpy.random as random
import math
from op_test import OpTest
from op_test_xpu import XPUOpTest
import paddle
paddle.enable_static()
class TestXPUIOUSimilarityOp(XPUOpTest):
def test_check_output(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
def setUp(self):
self.op_type = "iou_similarity"
self.boxes1 = random.rand(2, 4).astype('float32')
self.boxes2 = random.rand(3, 4).astype('float32')
self.output = random.rand(2, 3).astype('float32')
self.box_normalized = False
# run python iou computation
self._compute_iou()
self.inputs = {'X': self.boxes1, 'Y': self.boxes2}
self.attrs = {"box_normalized": self.box_normalized, 'use_xpu': True}
self.outputs = {'Out': self.output}
def _compute_iou(self, ):
for row in range(self.boxes1.shape[0]):
for col in range(self.boxes2.shape[0]):
xmin1, ymin1, xmax1, ymax1 = self.boxes1[row]
xmin2, ymin2, xmax2, ymax2 = self.boxes2[col]
if not self.box_normalized:
area1 = (ymax1 - ymin1 + 1) * (xmax1 - xmin1 + 1)
area2 = (ymax2 - ymin2 + 1) * (xmax2 - xmin2 + 1)
else:
area1 = (ymax1 - ymin1) * (xmax1 - xmin1)
area2 = (ymax2 - ymin2) * (xmax2 - xmin2)
inter_xmax = min(xmax1, xmax2)
inter_ymax = min(ymax1, ymax2)
inter_xmin = max(xmin1, xmin2)
inter_ymin = max(ymin1, ymin2)
inter_height = inter_ymax - inter_ymin
inter_width = inter_xmax - inter_xmin
if not self.box_normalized:
inter_height += 1
inter_width += 1
inter_height = max(inter_height, 0)
inter_width = max(inter_width, 0)
inter_area = inter_width * inter_height
union_area = area1 + area2 - inter_area
sim_score = inter_area / union_area
self.output[row, col] = sim_score
class TestXPUIOUSimilarityOpWithLoD(TestXPUIOUSimilarityOp):
def test_check_output(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_output_with_place(place, check_dygraph=False)
def setUp(self):
super(TestXPUIOUSimilarityOpWithLoD, self).setUp()
self.boxes1_lod = [[1, 1]]
self.output_lod = [[1, 1]]
self.box_normalized = False
# run python iou computation
self._compute_iou()
self.inputs = {'X': (self.boxes1, self.boxes1_lod), 'Y': self.boxes2}
self.attrs = {"box_normalized": self.box_normalized}
self.outputs = {'Out': (self.output, self.output_lod)}
class TestXPUIOUSimilarityOpWithBoxNormalized(TestXPUIOUSimilarityOp):
def test_check_output(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_output_with_place(place, check_dygraph=False)
def setUp(self):
super(TestXPUIOUSimilarityOpWithBoxNormalized, self).setUp()
self.boxes1_lod = [[1, 1]]
self.output_lod = [[1, 1]]
self.box_normalized = True
# run python iou computation
self._compute_iou()
self.inputs = {'X': (self.boxes1, self.boxes1_lod), 'Y': self.boxes2}
self.attrs = {"box_normalized": self.box_normalized}
self.outputs = {'Out': (self.output, self.output_lod)}
if __name__ == '__main__':
unittest.main()