From 58bfaea8afcc2b30c5f73a5c52f1cafc6a8682f2 Mon Sep 17 00:00:00 2001 From: gaoyuan Date: Wed, 31 Jan 2018 21:32:22 +0800 Subject: [PATCH] update according to the code review --- paddle/operators/box_coder_op.cc | 49 ++++++++++++-------- paddle/operators/box_coder_op.cu | 4 +- paddle/operators/box_coder_op.h | 79 +++++++++++++------------------- 3 files changed, 66 insertions(+), 66 deletions(-) diff --git a/paddle/operators/box_coder_op.cc b/paddle/operators/box_coder_op.cc index 0cb20a4182e..41123f9b6e5 100644 --- a/paddle/operators/box_coder_op.cc +++ b/paddle/operators/box_coder_op.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -31,23 +31,21 @@ class BoxCoderOp : public framework::OperatorWithKernel { auto prior_box_var_dims = ctx->GetInputDim("PriorBoxVar"); auto target_box_dims = ctx->GetInputDim("TargetBox"); - PADDLE_ENFORCE_EQ(prior_box_dims.size(), 2UL, - "The shape of PriorBox is [N, 4]"); - PADDLE_ENFORCE_EQ(prior_box_dims[1], 4UL, - "The shape of PriorBox is [N, 4]"); - PADDLE_ENFORCE_EQ(prior_box_var_dims.size(), 2UL, - "The shape of PriorBoxVar is [N, 4]"); - PADDLE_ENFORCE_EQ(prior_box_var_dims[1], 4UL, - "The shape of PriorBoxVar is [N, 4]"); - PADDLE_ENFORCE_EQ(target_box_dims.size(), 2UL, - "The shape of TargetBox is [M, 4]"); - PADDLE_ENFORCE_EQ(target_box_dims[1], 4UL, + PADDLE_ENFORCE_EQ(prior_box_dims.size(), 2, + "The rank of Input of PriorBoxVar must be 2"); + PADDLE_ENFORCE_EQ(prior_box_dims[1], 4, "The shape of PriorBox is [N, 4]"); + PADDLE_ENFORCE_EQ(prior_box_dims, prior_box_var_dims); + PADDLE_ENFORCE_EQ(target_box_dims.size(), 2, + "The rank of Input of TargetBox must be 2"); + PADDLE_ENFORCE_EQ(target_box_dims[1], 4, "The shape of TargetBox is [M, 4]"); GetBoxCodeType(ctx->Attrs().Get("code_type")); - ctx->SetOutputDim("OutputBox", framework::make_ddim({target_box_dims[0], - target_box_dims[1]})); + ctx->SetOutputDim( + "OutputBox", + framework::make_ddim({target_box_dims[0], prior_box_dims[0], 4})); + ctx->ShareLoD("TargetBox", /*->*/ "OutputBox"); } }; @@ -58,7 +56,7 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker { AddInput( "PriorBox", "(Tensor, default Tensor) " - "Box list PriorBox is a 2-D Tensor with shape [M, 4] holds N boxes, " + "Box list PriorBox is a 2-D Tensor with shape [M, 4] holds M boxes, " "each box is represented as [xmin, ymin, xmax, ymax], " "[xmin, ymin] is the left top coordinate of the anchor box, " "if the input is image feature map, they are close to the origin " @@ -66,7 +64,7 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker { "coordinate of the anchor box."); AddInput("PriorBoxVar", "(Tensor, default Tensor) " - "PriorBoxVar is a 2-D Tensor with shape [M, 4] holds N group " + "PriorBoxVar is a 2-D Tensor with shape [M, 4] holds M group " "of variance."); AddInput( "TargetBox", @@ -85,14 +83,29 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker { .InEnum({"encode_center_size", "decode_center_size"}); AddOutput( "OutputBox", - "(Tensor, default Tensor)" + "(LoDTensor or Tensor) " "(Tensor) The output of box_coder_op, a tensor with shape [N, M, 4] " "representing the result of N target boxes encoded/decoded with " "M Prior boxes and variances."); AddComment(R"DOC( Bounding Box Coder Operator. -Encode/Decode the priorbox information with the target bounding box. +Encode/Decode the target bounding box with the priorbox information. +The Encoding schema described below: +ox = (tx - px) / pw / pxv +oy = (ty - py) / ph / pyv +ow = log(abs(tw / pw)) / pwv +oh = log(abs(th / ph)) / phv +The Decoding schema described below: +ox = (pw * pxv * tx * + px) - tw / 2 +oy = (ph * pyv * ty * + py) - th / 2 +ow = exp(pwv * tw) * pw + tw / 2 +oh = exp(phv * th) * ph + th / 2 +where tx, ty, tw, th denote the target box's center coordinates, width and +height respectively. Similarly, px, py, pw, ph denote the priorbox's(anchor) +center coordinates, width and height. pxv, pyv, pwv, phv denote the variance +of the priorbox and ox, oy, ow, oh denote the encoded/decoded coordinates, +width and height. )DOC"); } }; diff --git a/paddle/operators/box_coder_op.cu b/paddle/operators/box_coder_op.cu index 4055ded1f8b..9e2ea8cc674 100644 --- a/paddle/operators/box_coder_op.cu +++ b/paddle/operators/box_coder_op.cu @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -108,7 +108,7 @@ class BoxCoderCUDAKernel : public framework::OpKernel { auto* output_box = context.Output("OutputBox"); if (target_box->lod().size()) { - PADDLE_ENFORCE_EQ(target_box->lod().size(), 1UL, + PADDLE_ENFORCE_EQ(target_box->lod().size(), 1, "Only support 1 level of LoD."); } auto row = target_box->dims()[0]; diff --git a/paddle/operators/box_coder_op.h b/paddle/operators/box_coder_op.h index 3865da40c33..d1c9a404597 100644 --- a/paddle/operators/box_coder_op.h +++ b/paddle/operators/box_coder_op.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -35,67 +35,52 @@ class BoxCoderKernel : public framework::OpKernel { public: void EncodeCenterSize(const Tensor& target_box, const Tensor& prior_box, const Tensor& prior_box_var, T* output) const { - PADDLE_ENFORCE_EQ(target_box.dims().size(), 2, - "The rank of target_box must be 2."); - PADDLE_ENFORCE_EQ(prior_box.dims().size(), 2, - "The rank of prior_box must be 2."); - PADDLE_ENFORCE_EQ(prior_box_var.dims().size(), 2, - "The rank of prior_box_var must be 2."); - PADDLE_ENFORCE_EQ(prior_box.dims()[0], prior_box_var.dims()[0], - "The dims of prior_box must equal to prior_box_var."); - int64_t row = target_box.dims()[0]; int64_t col = prior_box.dims()[0]; + int64_t len = prior_box.dims()[1]; auto* target_box_data = target_box.data(); auto* prior_box_data = prior_box.data(); auto* prior_box_var_data = prior_box_var.data(); for (int64_t i = 0; i < row; ++i) { for (int64_t j = 0; j < col; ++j) { - T prior_box_width = prior_box_data[j * 4 + 2] - prior_box_data[j * 4]; + T prior_box_width = + prior_box_data[j * len + 2] - prior_box_data[j * len]; T prior_box_height = - prior_box_data[j * 4 + 3] - prior_box_data[j * 4 + 1]; + prior_box_data[j * len + 3] - prior_box_data[j * len + 1]; T prior_box_center_x = - (prior_box_data[j * 4 + 2] + prior_box_data[j * 4]) / 2; + (prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2; T prior_box_center_y = - (prior_box_data[j * 4 + 3] + prior_box_data[j * 4 + 1]) / 2; + (prior_box_data[j * len + 3] + prior_box_data[j * len + 1]) / 2; T target_box_center_x = - (target_box_data[i * 4 + 2] + target_box_data[i * 4]) / 2; + (target_box_data[i * len + 2] + target_box_data[i * len]) / 2; T target_box_center_y = - (target_box_data[i * 4 + 3] + target_box_data[i * 4 + 1]) / 2; + (target_box_data[i * len + 3] + target_box_data[i * len + 1]) / 2; T target_box_width = - target_box_data[i * 4 + 2] - target_box_data[i * 4]; + target_box_data[i * len + 2] - target_box_data[i * len]; T target_box_height = - target_box_data[i * 4 + 3] - target_box_data[i * 4 + 1]; + target_box_data[i * len + 3] - target_box_data[i * len + 1]; - size_t offset = i * col * 4 + j * 4; + size_t offset = i * col * len + j * len; output[offset] = (target_box_center_x - prior_box_center_x) / - prior_box_width / prior_box_var_data[j * 4]; + prior_box_width / prior_box_var_data[j * len]; output[offset + 1] = (target_box_center_y - prior_box_center_y) / - prior_box_height / prior_box_var_data[j * 4 + 1]; + prior_box_height / prior_box_var_data[j * len + 1]; output[offset + 2] = std::log(std::fabs(target_box_width / prior_box_width)) / - prior_box_var_data[j * 4 + 2]; + prior_box_var_data[j * len + 2]; output[offset + 3] = std::log(std::fabs(target_box_height / prior_box_height)) / - prior_box_var_data[j * 4 + 3]; + prior_box_var_data[j * len + 3]; } } } void DecodeCenterSize(const Tensor& target_box, const Tensor& prior_box, const Tensor& prior_box_var, T* output) const { - PADDLE_ENFORCE_EQ(target_box.dims().size(), 2, - "The rank of target_box must be 2."); - PADDLE_ENFORCE_EQ(prior_box.dims().size(), 2, - "The rank of prior_box must be 2."); - PADDLE_ENFORCE_EQ(prior_box_var.dims().size(), 2, - "The rank of prior_box_var must be 2."); - PADDLE_ENFORCE_EQ(prior_box.dims()[0], prior_box_var.dims()[0], - "The dims of prior_box must equal to prior_box_var."); - int64_t row = target_box.dims()[0]; int64_t col = prior_box.dims()[0]; + int64_t len = prior_box.dims()[1]; auto* target_box_data = target_box.data(); auto* prior_box_data = prior_box.data(); @@ -103,29 +88,30 @@ class BoxCoderKernel : public framework::OpKernel { for (int64_t i = 0; i < row; ++i) { for (int64_t j = 0; j < col; ++j) { - T prior_box_width = prior_box_data[j * 4 + 2] - prior_box_data[j * 4]; + T prior_box_width = + prior_box_data[j * len + 2] - prior_box_data[j * len]; T prior_box_height = - prior_box_data[j * 4 + 3] - prior_box_data[j * 4 + 1]; + prior_box_data[j * len + 3] - prior_box_data[j * len + 1]; T prior_box_center_x = - (prior_box_data[j * 4 + 2] + prior_box_data[j * 4]) / 2; + (prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2; T prior_box_center_y = - (prior_box_data[j * 4 + 3] + prior_box_data[j * 4 + 1]) / 2; + (prior_box_data[j * len + 3] + prior_box_data[j * len + 1]) / 2; - T target_box_center_x = prior_box_var_data[j * 4] * - target_box_data[i * 4] * prior_box_width + + T target_box_center_x = prior_box_var_data[j * len] * + target_box_data[i * len] * prior_box_width + prior_box_center_x; - T target_box_center_y = prior_box_var_data[j * 4 + 1] * - target_box_data[i * 4 + 1] * + T target_box_center_y = prior_box_var_data[j * len + 1] * + target_box_data[i * len + 1] * prior_box_height + prior_box_center_y; - T target_box_width = std::exp(prior_box_var_data[j * 4 + 2] * - target_box_data[i * 4 + 2]) * + T target_box_width = std::exp(prior_box_var_data[j * len + 2] * + target_box_data[i * len + 2]) * prior_box_width; - T target_box_height = std::exp(prior_box_var_data[j * 4 + 3] * - target_box_data[i * 4 + 3]) * + T target_box_height = std::exp(prior_box_var_data[j * len + 3] * + target_box_data[i * len + 3]) * prior_box_height; - size_t offset = i * col * 4 + j * 4; + size_t offset = i * col * len + j * len; output[offset] = target_box_center_x - target_box_width / 2; output[offset + 1] = target_box_center_y - target_box_height / 2; output[offset + 2] = target_box_center_x + target_box_width / 2; @@ -146,8 +132,9 @@ class BoxCoderKernel : public framework::OpKernel { } auto row = target_box->dims()[0]; auto col = prior_box->dims()[0]; + auto len = prior_box->dims()[1]; - output_box->mutable_data({row, col, 4}, context.GetPlace()); + output_box->mutable_data({row, col, len}, context.GetPlace()); auto code_type = GetBoxCodeType(context.Attr("code_type")); T* output = output_box->data(); -- GitLab