diff --git a/paddle/operators/iou_similarity_op.cc b/paddle/operators/iou_similarity_op.cc new file mode 100755 index 0000000000000000000000000000000000000000..c520b28b83e66dbf53d2e19985370be4a2f69e23 --- /dev/null +++ b/paddle/operators/iou_similarity_op.cc @@ -0,0 +1,96 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/iou_similarity_op.h" + +namespace paddle { +namespace operators { + +class IOUSimilarityOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of IOUSimilarityOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), + "Input(Y) of IOUSimilarityOp should not be null."); + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + + PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "The rank of Input(X) must be 2."); + PADDLE_ENFORCE_EQ(x_dims[1], 4UL, "The shape of X is [N, 4]"); + PADDLE_ENFORCE_EQ(y_dims.size(), 2UL, "The rank of Input(Y) must be 2."); + PADDLE_ENFORCE_EQ(y_dims[1], 4UL, "The shape of Y is [M, 4]"); + + ctx->ShareLoD("X", /*->*/ "Out"); + ctx->SetOutputDim("Out", framework::make_ddim({x_dims[0], y_dims[0]})); + } +}; + +class IOUSimilarityOpMaker : public framework::OpProtoAndCheckerMaker { + public: + IOUSimilarityOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(LoDTensor, default LoDTensor) " + "Box list X is a 2-D LoDTensor with shape [N, 4] holds N boxes, " + "each box is represented as [xmin, ymin, xmax, ymax], " + "the shape of X is [N, 4]. [xmin, ymin] is the left top " + "coordinate of the box if the input is image feature map, they " + "are close to the origin of the coordinate system. " + "[xmax, ymax] is the right bottom coordinate of the box. " + "This tensor can contain LoD information to represent a batch " + "of inputs. One instance of this batch can contain different " + "numbers of entities."); + AddInput("Y", + "(Tensor, default Tensor) " + "Box list Y holds M boxes, each box is represented as " + "[xmin, ymin, xmax, ymax], the shape of X is [N, 4]. " + "[xmin, ymin] is the left top coordinate of the box if the " + "input is image feature map, and [xmax, ymax] is the right " + "bottom coordinate of the box."); + + AddOutput("Out", + "(LoDTensor, the lod is same as input X) The output of " + "iou_similarity op, a tensor with shape [N, M] " + "representing pairwise iou scores."); + + AddComment(R"DOC( +IOU Similarity Operator. +Computes intersection-over-union (IOU) between two box lists. + Box list 'X' should be a LoDTensor and 'Y' is a common Tensor, + boxes in 'Y' are shared by all instance of the batched inputs of X. + Given two boxes A and B, the calculation of IOU is as follows: + +$$ +IOU(A, B) = +\frac{area(A\cap B)}{area(A)+area(B)-area(A\cap B)} +$$ + +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(iou_similarity, ops::IOUSimilarityOp, + ops::IOUSimilarityOpMaker); + +REGISTER_OP_CPU_KERNEL( + iou_similarity, + ops::IOUSimilarityKernel, + ops::IOUSimilarityKernel); diff --git a/paddle/operators/iou_similarity_op.cu b/paddle/operators/iou_similarity_op.cu new file mode 100755 index 0000000000000000000000000000000000000000..fa5052624618c35875b241419946f69b776c81d4 --- /dev/null +++ b/paddle/operators/iou_similarity_op.cu @@ -0,0 +1,21 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/iou_similarity_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + iou_similarity, + ops::IOUSimilarityKernel, + ops::IOUSimilarityKernel); diff --git a/paddle/operators/iou_similarity_op.h b/paddle/operators/iou_similarity_op.h new file mode 100644 index 0000000000000000000000000000000000000000..e36177069d7b18ea23759f99c4679218fbfd32b8 --- /dev/null +++ b/paddle/operators/iou_similarity_op.h @@ -0,0 +1,90 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/op_registry.h" +#include "paddle/platform/for_range.h" + +template +inline HOSTDEVICE T IOUSimilarity(T xmin1, T ymin1, T xmax1, T ymax1, T xmin2, + T ymin2, T xmax2, T ymax2) { + constexpr T zero = static_cast(0); + T area1 = (ymax1 - ymin1) * (xmax1 - xmin1); + T area2 = (ymax2 - ymin2) * (xmax2 - xmin2); + T inter_xmax = xmax1 > xmax2 ? xmax2 : xmax1; + T inter_ymax = ymax1 > ymax2 ? ymax2 : ymax1; + T inter_xmin = xmin1 > xmin2 ? xmin1 : xmin2; + T inter_ymin = ymin1 > ymin2 ? ymin1 : ymin2; + T inter_height = inter_ymax - inter_ymin; + T inter_width = inter_xmax - inter_xmin; + inter_height = inter_height > zero ? inter_height : zero; + inter_width = inter_width > zero ? inter_width : zero; + T inter_area = inter_width * inter_height; + T union_area = area1 + area2 - inter_area; + T sim_score = inter_area / union_area; + return sim_score; +} + +template +struct IOUSimilarityFunctor { + IOUSimilarityFunctor(const T* x, const T* y, T* z, int cols) + : x_(x), y_(y), z_(z), cols_(static_cast(cols)) {} + + inline HOSTDEVICE void operator()(size_t row_id) const { + T x_min1 = x_[row_id * 4]; + T y_min1 = x_[row_id * 4 + 1]; + T x_max1 = x_[row_id * 4 + 2]; + T y_max1 = x_[row_id * 4 + 3]; + for (size_t i = 0; i < cols_; ++i) { + T x_min2 = y_[i * 4]; + T y_min2 = y_[i * 4 + 1]; + T x_max2 = y_[i * 4 + 2]; + T y_max2 = y_[i * 4 + 3]; + + T sim = IOUSimilarity(x_min1, y_min1, x_max1, y_max1, x_min2, y_min2, + x_max2, y_max2); + + z_[row_id * cols_ + i] = sim; + } + } + const T* x_; + const T* y_; + T* z_; + const size_t cols_; +}; + +namespace paddle { +namespace operators { + +template +class IOUSimilarityKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const framework::LoDTensor* in_x = ctx.Input("X"); + const framework::Tensor* in_y = ctx.Input("Y"); + framework::LoDTensor* out = ctx.Output("Out"); + + int x_n = in_x->dims()[0]; + int y_n = in_y->dims()[0]; + IOUSimilarityFunctor functor(in_x->data(), in_y->data(), + out->mutable_data(ctx.GetPlace()), y_n); + + platform::ForRange for_range( + static_cast(ctx.device_context()), x_n); + for_range(functor); + } +}; // namespace operators + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_iou_similarity_op.py b/python/paddle/v2/fluid/tests/test_iou_similarity_op.py new file mode 100755 index 0000000000000000000000000000000000000000..128f2e4977195a563efcd26364cc6261da2dd685 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_iou_similarity_op.py @@ -0,0 +1,55 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import sys +import math +from op_test import OpTest + + +class TestIOUSimilarityOp(OpTest): + def test_check_output(self): + self.check_output() + + def setUp(self): + self.op_type = "iou_similarity" + self.boxes1 = np.array( + [[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]]).astype('float32') + self.boxes2 = np.array([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0], + [0.0, 0.0, 20.0, 20.0]]).astype('float32') + self.output = np.array( + [[2.0 / 16.0, 0, 6.0 / 400.0], + [1.0 / 16.0, 0.0, 5.0 / 400.0]]).astype('float32') + + self.inputs = {'X': self.boxes1, 'Y': self.boxes2} + + self.outputs = {'Out': self.output} + + +class TestIOUSimilarityOpWithLoD(TestIOUSimilarityOp): + def test_check_output(self): + self.check_output() + + def setUp(self): + super(TestIOUSimilarityOpWithLoD, self).setUp() + self.boxes1_lod = [[0, 1, 2]] + self.output_lod = [[0, 1, 2]] + + self.inputs = {'X': (self.boxes1, self.boxes1_lod), 'Y': self.boxes2} + self.outputs = {'Out': (self.output, self.output_lod)} + + +if __name__ == '__main__': + unittest.main()