提交 a05d25cf 编写于 作者: W wanghaox

update code and doc, change input x to LoDTensor

上级 d4587959
...@@ -44,11 +44,14 @@ class IOUSimilarityOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -44,11 +44,14 @@ class IOUSimilarityOpMaker : public framework::OpProtoAndCheckerMaker {
IOUSimilarityOpMaker(OpProto *proto, OpAttrChecker *op_checker) IOUSimilarityOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor, default Tensor<float>) " "(LoDTensor, default LoDTensor<float>) "
"Box list X holds N boxes, each box is " "Box list X is a 2-D LoDTensor with shape [N, 4] holds N boxes, "
"represented as [xmin, ymin, xmax, ymax], the shape of X is [N, " "each box is represented as [xmin, ymin, xmax, ymax], "
"4]. [xmin, ymin] is the lower left coordinate of the box, and " "the shape of X is [N, 4]. [xmin, ymin] is the lower left "
"[xmax, ymax] is the right upper coordinate of the box."); "coordinate of the box, and [xmax, ymax] is the right upper "
"coordinate of the box.This tensor can contain LoD information "
"to represent a batch of inputs. One instance of this batch can "
"contain different numbers of entities.");
AddInput("Y", AddInput("Y",
"(Tensor, default Tensor<float>) " "(Tensor, default Tensor<float>) "
"Box list Y holds M boxes, each box is " "Box list Y holds M boxes, each box is "
...@@ -56,14 +59,23 @@ class IOUSimilarityOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -56,14 +59,23 @@ class IOUSimilarityOpMaker : public framework::OpProtoAndCheckerMaker {
"4]. [xmin, ymin] is the lower left coordinate of the box, and " "4]. [xmin, ymin] is the lower left coordinate of the box, and "
"[xmax, ymax] is the right upper coordinate of the box."); "[xmax, ymax] is the right upper coordinate of the box.");
AddOutput( AddOutput("Out",
"Out", "(LoDTensor or Tensor, the lod is same as input X) The output of "
"(Tensor) The output of iou_similarity op, a tensor with shape [N, M] " "iou_similarity op, a tensor with shape [N, M] "
"representing pairwise iou scores."); "representing pairwise iou scores.");
AddComment(R"DOC( AddComment(R"DOC(
IOU Similarity Operator. IOU Similarity Operator.
Computes intersection-over-union (IOU) between two box lists. Computes intersection-over-union (IOU) between two box lists.
Box list 'X' should be a LoDTensor and 'Y' is a common Tensor,
boxes in 'Y' are shared by all input images.
Given two box A and B, the calculation of IOU is as follows:
$$
IOU(A, B) =
\frac{area(A\cap B)}{area(A)+area(B)-area(A\cap B)}
$$
)DOC"); )DOC");
} }
}; };
......
...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/iou_similarity_op.h" #include "paddle/operators/iou_similarity_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
...@@ -71,9 +71,9 @@ template <typename DeviceContext, typename T> ...@@ -71,9 +71,9 @@ template <typename DeviceContext, typename T>
class IOUSimilarityKernel : public framework::OpKernel<T> { class IOUSimilarityKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor* in_x = ctx.Input<framework::Tensor>("X"); const framework::LoDTensor* in_x = ctx.Input<framework::LoDTensor>("X");
const framework::Tensor* in_y = ctx.Input<framework::Tensor>("Y"); const framework::Tensor* in_y = ctx.Input<framework::Tensor>("Y");
framework::Tensor* out = ctx.Output<framework::Tensor>("Out"); framework::LoDTensor* out = ctx.Output<framework::LoDTensor>("Out");
int x_n = in_x->dims()[0]; int x_n = in_x->dims()[0];
int y_n = in_y->dims()[0]; int y_n = in_y->dims()[0];
...@@ -83,6 +83,8 @@ class IOUSimilarityKernel : public framework::OpKernel<T> { ...@@ -83,6 +83,8 @@ class IOUSimilarityKernel : public framework::OpKernel<T> {
platform::ForRange<DeviceContext> for_range( platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()), x_n); static_cast<const DeviceContext&>(ctx.device_context()), x_n);
for_range(functor); for_range(functor);
out->set_lod(in_x->lod());
} }
}; // namespace operators }; // namespace operators
......
...@@ -38,5 +38,18 @@ class TestIOUSimilarityOp(OpTest): ...@@ -38,5 +38,18 @@ class TestIOUSimilarityOp(OpTest):
self.outputs = {'Out': self.output} self.outputs = {'Out': self.output}
class TestIOUSimilarityOpWithLoD(TestIOUSimilarityOp):
def test_check_output(self):
self.check_output()
def setUp(self):
super(TestIOUSimilarityOpWithLoD, self).setUp()
self.boxes1_lod = [[0, 1, 2]]
self.output_lod = [[0, 1, 2]]
self.inputs = {'X': (self.boxes1, self.boxes1_lod), 'Y': self.boxes2}
self.outputs = {'Out': (self.output, self.output_lod)}
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册