未验证 提交 72ee3c62 编写于 作者: J jerrywgz 提交者: GitHub

Merge pull request #15398 from jerrywgz/add_axis_for_boxcoder

Add axis for boxcoder
......@@ -322,7 +322,7 @@ paddle.fluid.layers.generate_proposal_labels ArgSpec(args=['rpn_rois', 'gt_class
paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None))
paddle.fluid.layers.generate_mask_labels ArgSpec(args=['im_info', 'gt_classes', 'is_crowd', 'gt_segms', 'rois', 'labels_int32', 'num_classes', 'resolution'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None))
paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name', 'axis'], varargs=None, keywords=None, defaults=('encode_center_size', True, None, 0))
paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'class_num', 'ignore_thresh', 'loss_weight_xy', 'loss_weight_wh', 'loss_weight_conf_target', 'loss_weight_conf_notarget', 'loss_weight_class', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None))
paddle.fluid.layers.multiclass_nms ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None))
......
......@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/box_coder_op.h"
#include <vector>
namespace paddle {
namespace operators {
......@@ -32,32 +33,57 @@ class BoxCoderOp : public framework::OperatorWithKernel {
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(prior_box_dims.size(), 2,
"The rank of Input of PriorBoxVar must be 2");
"The rank of Input PriorBox must be 2");
PADDLE_ENFORCE_EQ(prior_box_dims[1], 4,
"The shape of PriorBox is [N, 4]");
if (ctx->HasInput("PriorBoxVar")) {
auto prior_box_var_dims = ctx->GetInputDim("PriorBoxVar");
PADDLE_ENFORCE_EQ(prior_box_dims, prior_box_var_dims);
PADDLE_ENFORCE(
prior_box_var_dims.size() == 1 || prior_box_var_dims.size() == 2,
"Input(PriorBoxVar) of BoxCoderOp should be 1 or 2.");
if (prior_box_var_dims.size() == 1) {
PADDLE_ENFORCE_EQ(
prior_box_var_dims[0], 4,
"The 1st dimension of Input(PriorBoxVar) should be 4"
"when the rank is 1.");
} else {
PADDLE_ENFORCE_EQ(
prior_box_dims, prior_box_var_dims,
"The dimension of Input(PriorBoxVar) should be equal to"
"the dimension of Input(PriorBox when the rank is 2.)");
}
}
}
auto code_type =
GetBoxCodeType(ctx->Attrs().Get<std::string>("code_type"));
if (code_type == BoxCodeType::kEncodeCenterSize) {
PADDLE_ENFORCE_EQ(target_box_dims.size(), 2,
"The rank of Input of TargetBox must be 2");
PADDLE_ENFORCE_EQ(target_box_dims[1], 4,
"The shape of TargetBox is [M, 4]");
} else if (code_type == BoxCodeType::kDecodeCenterSize) {
PADDLE_ENFORCE_EQ(target_box_dims.size(), 3,
"The rank of Input of TargetBox must be 3");
auto code_type = GetBoxCodeType(ctx->Attrs().Get<std::string>("code_type"));
int axis = ctx->Attrs().Get<int>("axis");
if (code_type == BoxCodeType::kEncodeCenterSize) {
PADDLE_ENFORCE_EQ(target_box_dims.size(), 2,
"The rank of Input TargetBox must be 2");
PADDLE_ENFORCE_EQ(target_box_dims[1], 4,
"The shape of TargetBox is [M, 4]");
ctx->SetOutputDim(
"OutputBox",
framework::make_ddim({target_box_dims[0], prior_box_dims[0], 4}));
} else if (code_type == BoxCodeType::kDecodeCenterSize) {
PADDLE_ENFORCE_EQ(target_box_dims.size(), 3,
"The rank of Input TargetBox must be 3");
if (axis == 0) {
PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]);
PADDLE_ENFORCE_EQ(target_box_dims[2], prior_box_dims[1]);
} else if (axis == 1) {
PADDLE_ENFORCE_EQ(target_box_dims[0], prior_box_dims[0]);
} else {
PADDLE_THROW("axis must be 0 or 1.");
}
PADDLE_ENFORCE_EQ(target_box_dims[2], prior_box_dims[1]);
ctx->ShareDim("TargetBox", /*->*/ "OutputBox");
}
if (code_type == BoxCodeType::kDecodeCenterSize && axis == 1) {
ctx->ShareLoD("PriorBox", /*->*/ "OutputBox");
} else {
ctx->ShareLoD("TargetBox", /*->*/ "OutputBox");
}
ctx->SetOutputDim(
"OutputBox",
framework::make_ddim({target_box_dims[0], prior_box_dims[0], 4}));
ctx->ShareLoD("TargetBox", /*->*/ "OutputBox");
}
};
......@@ -100,6 +126,21 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default true) "
"whether treat the priorbox as a noramlized box")
.SetDefault(true);
AddAttr<int>("axis",
"(int, default 0)"
"which axis in PriorBox to broadcast for box decode,"
"for example, if axis is 0 and TargetBox has shape"
"[N, M, 4] and PriorBox has shape [M, 4], then PriorBox "
"will broadcast to [N, M, 4] for decoding. It is only valid"
"when code type is decode_center_size")
.SetDefault(0)
.InEnum({0, 1});
AddAttr<std::vector<float>>(
"variance",
"(vector<float>, default {}),"
"variance of prior box with shape [4]. PriorBoxVar and variance can"
"not be provided at the same time.")
.SetDefault(std::vector<float>{});
AddOutput("OutputBox",
"(LoDTensor or Tensor) "
"When code_type is 'encode_center_size', the output tensor of "
......@@ -138,7 +179,11 @@ where `tx`, `ty`, `tw`, `th` denote the target box's center coordinates, width
and height respectively. Similarly, `px`, `py`, `pw`, `ph` denote the
priorbox's (anchor) center coordinates, width and height. `pxv`, `pyv`, `pwv`,
`phv` denote the variance of the priorbox and `ox`, `oy`, `ow`, `oh` denote the
encoded/decoded coordinates, width and height.
encoded/decoded coordinates, width and height.
During Box Decoding, two modes for broadcast are supported. Say target box has
shape [N, M, 4], and the shape of prior box can be [N, 4] or [M, 4]. Then prior
box will broadcast to target box along the assigned axis.
)DOC");
}
};
......
......@@ -9,6 +9,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/detection/box_coder_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
......@@ -16,11 +19,11 @@ namespace paddle {
namespace operators {
template <typename T>
__global__ void EncodeCenterSizeKernel(const T* prior_box_data,
const T* prior_box_var_data,
const T* target_box_data, const int row,
const int col, const int len,
const bool normalized, T* output) {
__global__ void EncodeCenterSizeKernel(
const T* prior_box_data, const T* prior_box_var_data,
const T* target_box_data, const int row, const int col, const int len,
const bool normalized, const T prior_box_var_size, const float* variance,
const int var_size, T* output) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < row * col) {
const int row_idx = idx / col;
......@@ -30,11 +33,9 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data,
T prior_box_height = prior_box_data[col_idx * len + 3] -
prior_box_data[col_idx * len + 1] +
(normalized == false);
T prior_box_center_x =
(prior_box_data[col_idx * len + 2] + prior_box_data[col_idx * len]) / 2;
T prior_box_center_y = (prior_box_data[col_idx * len + 3] +
prior_box_data[col_idx * len + 1]) /
2;
T prior_box_center_x = prior_box_data[col_idx * len] + prior_box_width / 2;
T prior_box_center_y =
prior_box_data[col_idx * len + 1] + prior_box_height / 2;
T target_box_center_x =
(target_box_data[row_idx * len + 2] + target_box_data[row_idx * len]) /
......@@ -55,58 +56,73 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data,
output[idx * len + 2] = log(fabs(target_box_width / prior_box_width));
output[idx * len + 3] = log(fabs(target_box_height / prior_box_height));
if (prior_box_var_data) {
output[idx * len] /= prior_box_var_data[col_idx * len];
output[idx * len + 1] /= prior_box_var_data[col_idx * len + 1];
output[idx * len + 2] /= prior_box_var_data[col_idx * len + 2];
output[idx * len + 3] /= prior_box_var_data[col_idx * len + 3];
int prior_var_offset = 0;
if (prior_box_var_size == 2) {
prior_var_offset = col_idx * len;
}
output[idx * len] /= prior_box_var_data[prior_var_offset];
output[idx * len + 1] /= prior_box_var_data[prior_var_offset + 1];
output[idx * len + 2] /= prior_box_var_data[prior_var_offset + 2];
output[idx * len + 3] /= prior_box_var_data[prior_var_offset + 3];
} else if (var_size == 4) {
for (int k = 0; k < 4; ++k) {
output[idx * len + k] /= static_cast<T>(variance[k]);
}
}
}
}
template <typename T>
__global__ void DecodeCenterSizeKernel(const T* prior_box_data,
const T* prior_box_var_data,
const T* target_box_data, const int row,
const int col, const int len,
const bool normalized, T* output) {
__global__ void DecodeCenterSizeKernel(
const T* prior_box_data, const T* prior_box_var_data,
const T* target_box_data, const int row, const int col, const int len,
const bool normalized, const T prior_box_var_size, const float* variance,
const int var_size, const int axis, T* output) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
int prior_box_offset = 0;
if (idx < row * col) {
const int col_idx = idx % col;
T prior_box_width = prior_box_data[col_idx * len + 2] -
prior_box_data[col_idx * len] + (normalized == false);
T prior_box_height = prior_box_data[col_idx * len + 3] -
prior_box_data[col_idx * len + 1] +
const int row_idx = idx / col;
prior_box_offset = axis == 0 ? col_idx * len : row_idx * len;
T prior_box_width = prior_box_data[prior_box_offset + 2] -
prior_box_data[prior_box_offset] +
(normalized == false);
T prior_box_height = prior_box_data[prior_box_offset + 3] -
prior_box_data[prior_box_offset + 1] +
(normalized == false);
T prior_box_center_x =
(prior_box_data[col_idx * len + 2] + prior_box_data[col_idx * len]) / 2;
T prior_box_center_y = (prior_box_data[col_idx * len + 3] +
prior_box_data[col_idx * len + 1]) /
2;
prior_box_data[prior_box_offset] + prior_box_width / 2;
T prior_box_center_y =
prior_box_data[prior_box_offset + 1] + prior_box_height / 2;
T target_box_width, target_box_height;
T target_box_center_x, target_box_center_y;
T box_var_x = T(1), box_var_y = T(1);
T box_var_w = T(1), box_var_h = T(1);
if (prior_box_var_data) {
target_box_width = exp(prior_box_var_data[col_idx * len + 2] *
target_box_data[idx * len + 2]) *
prior_box_width;
target_box_height = exp(prior_box_var_data[col_idx * len + 3] *
target_box_data[idx * len + 3]) *
prior_box_height;
target_box_center_x = prior_box_var_data[col_idx * len] *
target_box_data[idx * len] * prior_box_width +
prior_box_center_x;
target_box_center_y = prior_box_var_data[col_idx * len + 1] *
target_box_data[idx * len + 1] *
prior_box_height +
prior_box_center_y;
} else {
target_box_width = exp(target_box_data[idx * len + 2]) * prior_box_width;
target_box_height =
exp(target_box_data[idx * len + 3]) * prior_box_height;
target_box_center_x =
target_box_data[idx * len] * prior_box_width + prior_box_center_x;
target_box_center_y = target_box_data[idx * len + 1] * prior_box_height +
prior_box_center_y;
int prior_var_offset = 0;
if (prior_box_var_size == 2) {
prior_var_offset = axis == 0 ? col_idx * len : row_idx * len;
}
box_var_x = prior_box_var_data[prior_var_offset];
box_var_y = prior_box_var_data[prior_var_offset + 1];
box_var_w = prior_box_var_data[prior_var_offset + 2];
box_var_h = prior_box_var_data[prior_var_offset + 3];
} else if (var_size == 4) {
box_var_x = static_cast<T>(variance[0]);
box_var_y = static_cast<T>(variance[1]);
box_var_w = static_cast<T>(variance[2]);
box_var_h = static_cast<T>(variance[3]);
}
target_box_width =
exp(box_var_w * target_box_data[idx * len + 2]) * prior_box_width;
target_box_height =
exp(box_var_h * target_box_data[idx * len + 3]) * prior_box_height;
target_box_center_x =
box_var_x * target_box_data[idx * len] * prior_box_width +
prior_box_center_x;
target_box_center_y =
box_var_y * target_box_data[idx * len + 1] * prior_box_height +
prior_box_center_y;
output[idx * len] = target_box_center_x - target_box_width / 2;
output[idx * len + 1] = target_box_center_y - target_box_height / 2;
......@@ -127,36 +143,64 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
auto* output_box = context.Output<framework::Tensor>("OutputBox");
std::vector<float> variance = context.Attr<std::vector<float>>("variance");
const T* prior_box_data = prior_box->data<T>();
const T* target_box_data = target_box->data<T>();
const T* prior_box_var_data = nullptr;
if (prior_box_var) prior_box_var_data = prior_box_var->data<T>();
auto prior_box_var_size = 0;
if (prior_box_var) {
PADDLE_ENFORCE(variance.empty(),
"Input 'PriorBoxVar' and attribute 'variance' should not"
"be used at the same time.");
prior_box_var_data = prior_box_var->data<T>();
prior_box_var_size = prior_box_var->dims().size();
}
if (!(variance.empty())) {
PADDLE_ENFORCE(static_cast<int>(variance.size()) == 4,
"Size of attribute 'variance' should be 4");
}
if (target_box->lod().size()) {
PADDLE_ENFORCE_EQ(target_box->lod().size(), 1,
"Only support 1 level of LoD.");
}
const int var_size = static_cast<int>(variance.size());
auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
bool normalized = context.Attr<bool>("box_normalized");
int axis = context.Attr<int>("axis");
auto row = target_box->dims()[0];
auto col = prior_box->dims()[0];
if (code_type == BoxCodeType::kDecodeCenterSize) {
col = target_box->dims()[1];
}
auto len = prior_box->dims()[1];
int block = 512;
int grid = (row * col + block - 1) / block;
auto& device_ctx = context.cuda_device_context();
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(device_ctx);
int bytes = var_size * sizeof(float);
auto dev_var = allocator.Allocate(bytes);
float* dev_var_data = reinterpret_cast<float*>(dev_var->ptr());
auto cplace = platform::CPUPlace();
const auto gplace = boost::get<platform::CUDAPlace>(context.GetPlace());
memory::Copy(gplace, dev_var_data, cplace, &variance[0], bytes,
device_ctx.stream());
output_box->mutable_data<T>({row, col, len}, context.GetPlace());
T* output = output_box->data<T>();
auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
bool normalized = context.Attr<bool>("box_normalized");
if (code_type == BoxCodeType::kEncodeCenterSize) {
EncodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
prior_box_data, prior_box_var_data, target_box_data, row, col, len,
normalized, output);
normalized, prior_box_var_size, dev_var_data, var_size, output);
} else if (code_type == BoxCodeType::kDecodeCenterSize) {
DecodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
prior_box_data, prior_box_var_data, target_box_data, row, col, len,
normalized, output);
normalized, prior_box_var_size, dev_var_data, var_size, axis, output);
}
}
};
......
......@@ -11,6 +11,7 @@ limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
......@@ -34,7 +35,8 @@ class BoxCoderKernel : public framework::OpKernel<T> {
void EncodeCenterSize(const framework::Tensor* target_box,
const framework::Tensor* prior_box,
const framework::Tensor* prior_box_var,
const bool normalized, T* output) const {
const bool normalized,
const std::vector<float> variance, T* output) const {
int64_t row = target_box->dims()[0];
int64_t col = prior_box->dims()[0];
int64_t len = prior_box->dims()[1];
......@@ -53,10 +55,9 @@ class BoxCoderKernel : public framework::OpKernel<T> {
T prior_box_height = prior_box_data[j * len + 3] -
prior_box_data[j * len + 1] +
(normalized == false);
T prior_box_center_x =
(prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2;
T prior_box_center_x = prior_box_data[j * len] + prior_box_width / 2;
T prior_box_center_y =
(prior_box_data[j * len + 3] + prior_box_data[j * len + 1]) / 2;
prior_box_data[j * len + 1] + prior_box_height / 2;
T target_box_center_x =
(target_box_data[i * len + 2] + target_box_data[i * len]) / 2;
......@@ -78,10 +79,18 @@ class BoxCoderKernel : public framework::OpKernel<T> {
output[offset + 3] =
std::log(std::fabs(target_box_height / prior_box_height));
if (prior_box_var) {
output[offset] /= prior_box_var_data[j * len];
output[offset + 1] /= prior_box_var_data[j * len + 1];
output[offset + 2] /= prior_box_var_data[j * len + 2];
output[offset + 3] /= prior_box_var_data[j * len + 3];
int prior_var_offset = 0;
if (prior_box_var->dims().size() == 2) {
prior_var_offset = j * len;
}
output[offset] /= prior_box_var_data[prior_var_offset];
output[offset + 1] /= prior_box_var_data[prior_var_offset + 1];
output[offset + 2] /= prior_box_var_data[prior_var_offset + 2];
output[offset + 3] /= prior_box_var_data[prior_var_offset + 3];
} else if (!(variance.empty())) {
for (int k = 0; k < 4; ++k) {
output[offset + k] /= static_cast<T>(variance[k]);
}
}
}
}
......@@ -89,58 +98,71 @@ class BoxCoderKernel : public framework::OpKernel<T> {
void DecodeCenterSize(const framework::Tensor* target_box,
const framework::Tensor* prior_box,
const framework::Tensor* prior_box_var,
const bool normalized, T* output) const {
const bool normalized, const int axis,
const std::vector<float> variance, T* output) const {
int64_t row = target_box->dims()[0];
int64_t col = prior_box->dims()[0];
int64_t len = prior_box->dims()[1];
int64_t col = target_box->dims()[1];
int64_t len = target_box->dims()[2];
auto* target_box_data = target_box->data<T>();
auto* prior_box_data = prior_box->data<T>();
const T* prior_box_var_data = nullptr;
if (prior_box_var) prior_box_var_data = prior_box_var->data<T>();
int prior_box_offset = 0;
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(2)
#endif
for (int64_t i = 0; i < row; ++i) {
for (int64_t j = 0; j < col; ++j) {
size_t offset = i * col * len + j * len;
T prior_box_width = prior_box_data[j * len + 2] -
prior_box_data[j * len] + (normalized == false);
T prior_box_height = prior_box_data[j * len + 3] -
prior_box_data[j * len + 1] +
if (axis == 0) {
prior_box_offset = j * len;
} else if (axis == 1) {
prior_box_offset = i * len;
}
T prior_box_width = prior_box_data[prior_box_offset + 2] -
prior_box_data[prior_box_offset] +
(normalized == false);
T prior_box_height = prior_box_data[prior_box_offset + 3] -
prior_box_data[prior_box_offset + 1] +
(normalized == false);
T prior_box_center_x =
(prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2;
prior_box_data[prior_box_offset] + prior_box_width / 2;
T prior_box_center_y =
(prior_box_data[j * len + 3] + prior_box_data[j * len + 1]) / 2;
prior_box_data[prior_box_offset + 1] + prior_box_height / 2;
T target_box_center_x = 0, target_box_center_y = 0;
T target_box_width = 0, target_box_height = 0;
T box_var_x = T(1), box_var_y = T(1);
T box_var_w = T(1), box_var_h = T(1);
if (prior_box_var) {
target_box_center_x = prior_box_var_data[j * len] *
target_box_data[offset] * prior_box_width +
prior_box_center_x;
target_box_center_y = prior_box_var_data[j * len + 1] *
target_box_data[offset + 1] *
prior_box_height +
prior_box_center_y;
target_box_width = std::exp(prior_box_var_data[j * len + 2] *
target_box_data[offset + 2]) *
prior_box_width;
target_box_height = std::exp(prior_box_var_data[j * len + 3] *
target_box_data[offset + 3]) *
prior_box_height;
} else {
target_box_center_x =
target_box_data[offset] * prior_box_width + prior_box_center_x;
target_box_center_y = target_box_data[offset + 1] * prior_box_height +
prior_box_center_y;
target_box_width =
std::exp(target_box_data[offset + 2]) * prior_box_width;
target_box_height =
std::exp(target_box_data[offset + 3]) * prior_box_height;
int prior_var_offset = 0;
if (prior_box_var->dims().size() == 2) {
if (axis == 0)
prior_var_offset = j * len;
else if (axis == 1)
prior_var_offset = i * len;
}
box_var_x = prior_box_var_data[prior_var_offset];
box_var_y = prior_box_var_data[prior_var_offset + 1];
box_var_w = prior_box_var_data[prior_var_offset + 2];
box_var_h = prior_box_var_data[prior_var_offset + 3];
} else if (!(variance.empty())) {
box_var_x = static_cast<T>(variance[0]);
box_var_y = static_cast<T>(variance[1]);
box_var_w = static_cast<T>(variance[2]);
box_var_h = static_cast<T>(variance[3]);
}
target_box_center_x =
box_var_x * target_box_data[offset] * prior_box_width +
prior_box_center_x;
target_box_center_y =
box_var_y * target_box_data[offset + 1] * prior_box_height +
prior_box_center_y;
target_box_width =
std::exp(box_var_w * target_box_data[offset + 2]) * prior_box_width;
target_box_height = std::exp(box_var_h * target_box_data[offset + 3]) *
prior_box_height;
output[offset] = target_box_center_x - target_box_width / 2;
output[offset + 1] = target_box_center_y - target_box_height / 2;
......@@ -157,26 +179,40 @@ class BoxCoderKernel : public framework::OpKernel<T> {
auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
auto* output_box = context.Output<framework::Tensor>("OutputBox");
std::vector<float> variance = context.Attr<std::vector<float>>("variance");
const int axis = context.Attr<int>("axis");
if (target_box->lod().size()) {
PADDLE_ENFORCE_EQ(target_box->lod().size(), 1UL,
"Only support 1 level of LoD.");
}
if (prior_box_var) {
PADDLE_ENFORCE(variance.empty(),
"Input 'PriorBoxVar' and attribute 'variance' should not"
"be used at the same time.");
}
if (!(variance.empty())) {
PADDLE_ENFORCE(static_cast<int>(variance.size()) == 4,
"Size of attribute 'variance' should be 4");
}
auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
bool normalized = context.Attr<bool>("box_normalized");
auto row = target_box->dims()[0];
auto col = prior_box->dims()[0];
if (code_type == BoxCodeType::kDecodeCenterSize) {
col = target_box->dims()[1];
}
auto len = prior_box->dims()[1];
output_box->mutable_data<T>({row, col, len}, context.GetPlace());
auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
bool normalized = context.Attr<bool>("box_normalized");
T* output = output_box->data<T>();
if (code_type == BoxCodeType::kEncodeCenterSize) {
EncodeCenterSize(target_box, prior_box, prior_box_var, normalized,
output);
variance, output);
} else if (code_type == BoxCodeType::kDecodeCenterSize) {
DecodeCenterSize(target_box, prior_box, prior_box_var, normalized,
output);
DecodeCenterSize(target_box, prior_box, prior_box_var, normalized, axis,
variance, output);
}
}
};
......
......@@ -54,6 +54,9 @@ class SliceOp : public framework::OperatorWithKernel {
out_dims[axes[i]] = end - start;
}
ctx->SetOutputDim("Out", out_dims);
if (axes[0] != 0) {
ctx->ShareLoD("Input", /*->*/ "Out");
}
}
protected:
......
......@@ -346,19 +346,107 @@ def box_coder(prior_box,
target_box,
code_type="encode_center_size",
box_normalized=True,
name=None):
name=None,
axis=0):
"""
${comment}
**Box Coder Layer**
Encode/Decode the target bounding box with the priorbox information.
The Encoding schema described below:
.. math::
ox = (tx - px) / pw / pxv
oy = (ty - py) / ph / pyv
ow = \log(\abs(tw / pw)) / pwv
oh = \log(\abs(th / ph)) / phv
The Decoding schema described below:
.. math::
ox = (pw * pxv * tx * + px) - tw / 2
oy = (ph * pyv * ty * + py) - th / 2
ow = \exp(pwv * tw) * pw + tw / 2
oh = \exp(phv * th) * ph + th / 2
where `tx`, `ty`, `tw`, `th` denote the target box's center coordinates,
width and height respectively. Similarly, `px`, `py`, `pw`, `ph` denote
the priorbox's (anchor) center coordinates, width and height. `pxv`,
`pyv`, `pwv`, `phv` denote the variance of the priorbox and `ox`, `oy`,
`ow`, `oh` denote the encoded/decoded coordinates, width and height.
During Box Decoding, two modes for broadcast are supported. Say target
box has shape [N, M, 4], and the shape of prior box can be [N, 4] or
[M, 4]. Then prior box will broadcast to target box along the
assigned axis.
Args:
prior_box(${prior_box_type}): ${prior_box_comment}
prior_box_var(${prior_box_var_type}): ${prior_box_var_comment}
target_box(${target_box_type}): ${target_box_comment}
code_type(${code_type_type}): ${code_type_comment}
box_normalized(${box_normalized_type}): ${box_normalized_comment}
prior_box(Variable): Box list prior_box is a 2-D Tensor with shape
[M, 4] holds M boxes, each box is represented as
[xmin, ymin, xmax, ymax], [xmin, ymin] is the
left top coordinate of the anchor box, if the
input is image feature map, they are close to
the origin of the coordinate system. [xmax, ymax]
is the right bottom coordinate of the anchor box.
prior_box_var(Variable|list): prior_box_var supports two types of input.
One is variable with shape [M, 4] holds M group.
The other one is list consist of 4 elements
shared by all boxes.
target_box(Variable): This input can be a 2-D LoDTensor with shape
[N, 4] when code_type is 'encode_center_size'.
This input also can be a 3-D Tensor with shape
[N, M, 4] when code_type is 'decode_center_size'.
Each box is represented as
[xmin, ymin, xmax, ymax]. This tensor can
contain LoD information to represent a batch
of inputs.
code_type(string): The code type used with the target box. It can be
encode_center_size or decode_center_size
box_normalized(int): Whether treat the priorbox as a noramlized box.
Set true by default.
name(string): The name of box coder.
axis(int): Which axis in PriorBox to broadcast for box decode,
for example, if axis is 0 and TargetBox has shape
[N, M, 4] and PriorBox has shape [M, 4], then PriorBox
will broadcast to [N, M, 4] for decoding. It is only valid
when code type is decode_center_size. Set 0 by default.
Returns:
output_box(${output_box_type}): ${output_box_comment}
output_box(Variable): When code_type is 'encode_center_size', the
output tensor of box_coder_op with shape
[N, M, 4] representing the result of N target
boxes encoded with M Prior boxes and variances.
When code_type is 'decode_center_size',
N represents the batch size and M represents
the number of deocded boxes.
Examples:
.. code-block:: python
prior_box = fluid.layers.data(name='prior_box',
shape=[512, 4],
dtype='float32',
append_batch_size=False)
target_box = fluid.layers.data(name='target_box',
shape=[512,81,4],
dtype='float32',
append_batch_size=False)
output = fluid.layers.box_coder(prior_box=prior_box,
prior_box_var=[0.1,0.1,0.2,0.2],
target_box=target_box,
code_type="decode_center_size",
box_normalized=False,
axis=1)
"""
helper = LayerHelper("box_coder", **locals())
......@@ -369,15 +457,22 @@ def box_coder(prior_box,
output_box = helper.create_variable(
name=name, dtype=prior_box.dtype, persistable=False)
inputs = {"PriorBox": prior_box, "TargetBox": target_box}
attrs = {
"code_type": code_type,
"box_normalized": box_normalized,
"axis": axis
}
if isinstance(prior_box_var, Variable):
inputs['PriorBoxVar'] = prior_box_var
elif isinstance(prior_box_var, list):
attrs['variance'] = prior_box_var
else:
raise TypeError("Input variance of box_coder must be Variable or lisz")
helper.append_op(
type="box_coder",
inputs={
"PriorBox": prior_box,
"PriorBoxVar": prior_box_var,
"TargetBox": target_box
},
attrs={"code_type": code_type,
"box_normalized": box_normalized},
inputs=inputs,
attrs=attrs,
outputs={"OutputBox": output_box})
return output_box
......
......@@ -50,6 +50,19 @@ class TestDetection(unittest.TestCase):
self.assertEqual(out.shape[-1], 6)
print(str(program))
def test_box_coder_api(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[4], dtype='float32')
y = layers.data(name='z', shape=[4], dtype='float32', lod_level=1)
bcoder = layers.box_coder(
prior_box=x,
prior_box_var=[0.1, 0.2, 0.1, 0.2],
target_box=y,
code_type='encode_center_size')
self.assertIsNotNone(bcoder)
print(str(program))
def test_detection_api(self):
program = Program()
with program_guard(program):
......
......@@ -21,80 +21,80 @@ import math
from op_test import OpTest
def box_coder(target_box, prior_box, prior_box_var, output_box, code_type,
box_normalized):
prior_box_x = (
(prior_box[:, 2] + prior_box[:, 0]) / 2).reshape(1, prior_box.shape[0])
prior_box_y = (
(prior_box[:, 3] + prior_box[:, 1]) / 2).reshape(1, prior_box.shape[0])
prior_box_width = (
(prior_box[:, 2] - prior_box[:, 0])).reshape(1, prior_box.shape[0])
prior_box_height = (
(prior_box[:, 3] - prior_box[:, 1])).reshape(1, prior_box.shape[0])
prior_box_var = prior_box_var.reshape(1, prior_box_var.shape[0],
prior_box_var.shape[1])
if not box_normalized:
prior_box_height = prior_box_height + 1
prior_box_width = prior_box_width + 1
if (code_type == "EncodeCenterSize"):
target_box_x = ((target_box[:, 2] + target_box[:, 0]) / 2).reshape(
target_box.shape[0], 1)
target_box_y = ((target_box[:, 3] + target_box[:, 1]) / 2).reshape(
target_box.shape[0], 1)
target_box_width = ((target_box[:, 2] - target_box[:, 0])).reshape(
target_box.shape[0], 1)
target_box_height = ((target_box[:, 3] - target_box[:, 1])).reshape(
target_box.shape[0], 1)
if not box_normalized:
target_box_height = target_box_height + 1
target_box_width = target_box_width + 1
output_box[:,:,0] = (target_box_x - prior_box_x) / prior_box_width / \
prior_box_var[:,:,0]
output_box[:,:,1] = (target_box_y - prior_box_y) / prior_box_height / \
prior_box_var[:,:,1]
output_box[:,:,2] = np.log(np.fabs(target_box_width / prior_box_width)) / \
prior_box_var[:,:,2]
output_box[:,:,3] = np.log(np.fabs(target_box_height / prior_box_height)) / \
prior_box_var[:,:,3]
elif (code_type == "DecodeCenterSize"):
target_box_x = prior_box_var[:,:,0] * target_box[:,:,0] * \
prior_box_width + prior_box_x
target_box_y = prior_box_var[:,:,1] * target_box[:,:,1] * \
prior_box_height + prior_box_y
target_box_width = np.exp(prior_box_var[:,:,2] * target_box[:,:,2]) * \
prior_box_width
target_box_height = np.exp(prior_box_var[:,:,3] * target_box[:,:,3]) * \
prior_box_height
output_box[:, :, 0] = target_box_x - target_box_width / 2
output_box[:, :, 1] = target_box_y - target_box_height / 2
output_box[:, :, 2] = target_box_x + target_box_width / 2
output_box[:, :, 3] = target_box_y + target_box_height / 2
if not box_normalized:
output_box[:, :, 2] = output_box[:, :, 2] - 1
output_box[:, :, 3] = output_box[:, :, 3] - 1
def batch_box_coder(prior_box, prior_box_var, target_box, lod, code_type,
box_normalized):
n = target_box.shape[0]
m = prior_box.shape[0]
def box_decoder(t_box, p_box, pb_v, output_box, norm, axis=0):
pb_w = p_box[:, 2] - p_box[:, 0] + (norm == False)
pb_h = p_box[:, 3] - p_box[:, 1] + (norm == False)
pb_x = pb_w * 0.5 + p_box[:, 0]
pb_y = pb_h * 0.5 + p_box[:, 1]
shape = (1, p_box.shape[0]) if axis == 0 else (p_box.shape[0], 1)
pb_w = pb_w.reshape(shape)
pb_h = pb_h.reshape(shape)
pb_x = pb_x.reshape(shape)
pb_y = pb_y.reshape(shape)
if pb_v.ndim == 2:
pb_v = pb_v.reshape(1, pb_v.shape[0], pb_v.shape[1])
if pb_v.ndim == 1:
tb_x = pb_v[0] * t_box[:, :, 0] * pb_w + pb_x
tb_y = pb_v[1] * t_box[:, :, 1] * pb_h + pb_y
tb_w = np.exp(pb_v[2] * t_box[:, :, 2]) * pb_w
tb_h = np.exp(pb_v[3] * t_box[:, :, 3]) * pb_h
else:
tb_x = pb_v[:, :, 0] * t_box[:, :, 0] * pb_w + pb_x
tb_y = pb_v[:, :, 1] * t_box[:, :, 1] * pb_h + pb_y
tb_w = np.exp(pb_v[:, :, 2] * t_box[:, :, 2]) * pb_w
tb_h = np.exp(pb_v[:, :, 3] * t_box[:, :, 3]) * pb_h
output_box[:, :, 0] = tb_x - tb_w / 2
output_box[:, :, 1] = tb_y - tb_h / 2
output_box[:, :, 2] = tb_x + tb_w / 2 - (not norm)
output_box[:, :, 3] = tb_y + tb_h / 2 - (not norm)
def box_encoder(t_box, p_box, pb_v, output_box, norm):
pb_w = p_box[:, 2] - p_box[:, 0] + (norm == False)
pb_h = p_box[:, 3] - p_box[:, 1] + (norm == False)
pb_x = pb_w * 0.5 + p_box[:, 0]
pb_y = pb_h * 0.5 + p_box[:, 1]
shape = (1, p_box.shape[0])
pb_w = pb_w.reshape(shape)
pb_h = pb_h.reshape(shape)
pb_x = pb_x.reshape(shape)
pb_y = pb_y.reshape(shape)
if pb_v.ndim == 2:
pb_v = pb_v.reshape(1, pb_v.shape[0], pb_v.shape[1])
tb_x = ((t_box[:, 2] + t_box[:, 0]) / 2).reshape(t_box.shape[0], 1)
tb_y = ((t_box[:, 3] + t_box[:, 1]) / 2).reshape(t_box.shape[0], 1)
tb_w = (t_box[:, 2] - t_box[:, 0]).reshape(t_box.shape[0], 1) + (not norm)
tb_h = (t_box[:, 3] - t_box[:, 1]).reshape(t_box.shape[0], 1) + (not norm)
if pb_v.ndim == 1:
output_box[:, :, 0] = (tb_x - pb_x) / pb_w / pb_v[0]
output_box[:, :, 1] = (tb_y - pb_y) / pb_h / pb_v[1]
output_box[:, :, 2] = np.log(np.fabs(tb_w / pb_w)) / pb_v[2]
output_box[:, :, 3] = np.log(np.fabs(tb_h / pb_h)) / pb_v[3]
else:
output_box[:, :, 0] = (tb_x - pb_x) / pb_w / pb_v[:, :, 0]
output_box[:, :, 1] = (tb_y - pb_y) / pb_h / pb_v[:, :, 1]
output_box[:, :, 2] = np.log(np.fabs(tb_w / pb_w)) / pb_v[:, :, 2]
output_box[:, :, 3] = np.log(np.fabs(tb_h / pb_h)) / pb_v[:, :, 3]
def batch_box_coder(p_box, pb_v, t_box, lod, code_type, norm, axis=0):
n = t_box.shape[0]
m = p_box.shape[0]
if code_type == "DecodeCenterSize":
m = t_box.shape[1]
output_box = np.zeros((n, m, 4), dtype=np.float32)
cur_offset = 0
for i in range(len(lod)):
if (code_type == "EncodeCenterSize"):
box_coder(target_box[cur_offset:(cur_offset + lod[i]), :],
prior_box, prior_box_var,
output_box[cur_offset:(cur_offset + lod[i]), :, :],
code_type, box_normalized)
box_encoder(t_box[cur_offset:(cur_offset + lod[i]), :], p_box, pb_v,
output_box[cur_offset:(cur_offset + lod[i]), :, :],
norm)
elif (code_type == "DecodeCenterSize"):
box_coder(target_box[cur_offset:(cur_offset + lod[i]), :, :],
prior_box, prior_box_var,
output_box[cur_offset:(cur_offset + lod[i]), :, :],
code_type, box_normalized)
box_decoder(t_box, p_box, pb_v, output_box, norm, axis)
cur_offset += lod[i]
return output_box
......@@ -106,9 +106,35 @@ class TestBoxCoderOp(OpTest):
def setUp(self):
self.op_type = "box_coder"
lod = [[1, 1, 1, 1, 1]]
prior_box = np.random.random((10, 4)).astype('float32')
prior_box_var = np.random.random((10, 4)).astype('float32')
target_box = np.random.random((5, 10, 4)).astype('float32')
prior_box = np.random.random((81, 4)).astype('float32')
prior_box_var = np.random.random((81, 4)).astype('float32')
target_box = np.random.random((20, 81, 4)).astype('float32')
code_type = "DecodeCenterSize"
box_normalized = False
output_box = batch_box_coder(prior_box, prior_box_var, target_box,
lod[0], code_type, box_normalized)
self.inputs = {
'PriorBox': prior_box,
'PriorBoxVar': prior_box_var,
'TargetBox': target_box,
}
self.attrs = {
'code_type': 'decode_center_size',
'box_normalized': False
}
self.outputs = {'OutputBox': output_box}
class TestBoxCoderOpWithOneRankVar(OpTest):
def test_check_output(self):
self.check_output()
def setUp(self):
self.op_type = "box_coder"
lod = [[1, 1, 1, 1, 1]]
prior_box = np.random.random((81, 4)).astype('float32')
prior_box_var = np.random.random((4)).astype('float32')
target_box = np.random.random((20, 81, 4)).astype('float32')
code_type = "DecodeCenterSize"
box_normalized = False
output_box = batch_box_coder(prior_box, prior_box_var, target_box,
......@@ -133,9 +159,9 @@ class TestBoxCoderOpWithoutBoxVar(OpTest):
def setUp(self):
self.op_type = "box_coder"
lod = [[0, 1, 2, 3, 4, 5]]
prior_box = np.random.random((10, 4)).astype('float32')
prior_box_var = np.ones((10, 4)).astype('float32')
target_box = np.random.random((5, 10, 4)).astype('float32')
prior_box = np.random.random((81, 4)).astype('float32')
prior_box_var = np.ones((81, 4)).astype('float32')
target_box = np.random.random((20, 81, 4)).astype('float32')
code_type = "DecodeCenterSize"
box_normalized = False
output_box = batch_box_coder(prior_box, prior_box_var, target_box,
......@@ -158,10 +184,10 @@ class TestBoxCoderOpWithLoD(OpTest):
def setUp(self):
self.op_type = "box_coder"
lod = [[4, 8, 8]]
prior_box = np.random.random((10, 4)).astype('float32')
prior_box_var = np.random.random((10, 4)).astype('float32')
target_box = np.random.random((20, 4)).astype('float32')
lod = [[10, 20, 20]]
prior_box = np.random.random((20, 4)).astype('float32')
prior_box_var = np.random.random((20, 4)).astype('float32')
target_box = np.random.random((50, 4)).astype('float32')
code_type = "EncodeCenterSize"
box_normalized = True
output_box = batch_box_coder(prior_box, prior_box_var, target_box,
......@@ -176,5 +202,63 @@ class TestBoxCoderOpWithLoD(OpTest):
self.outputs = {'OutputBox': output_box}
class TestBoxCoderOpWithAxis(OpTest):
def test_check_output(self):
self.check_output()
def setUp(self):
self.op_type = "box_coder"
lod = [[1, 1, 1, 1, 1]]
prior_box = np.random.random((30, 4)).astype('float32')
prior_box_var = np.random.random((4)).astype('float32')
target_box = np.random.random((30, 81, 4)).astype('float32')
code_type = "DecodeCenterSize"
box_normalized = False
axis = 1
output_box = batch_box_coder(prior_box, prior_box_var, target_box,
lod[0], code_type, box_normalized, axis)
self.inputs = {
'PriorBox': prior_box,
'PriorBoxVar': prior_box_var,
'TargetBox': target_box,
}
self.attrs = {
'code_type': 'decode_center_size',
'box_normalized': False,
'axis': axis
}
self.outputs = {'OutputBox': output_box}
class TestBoxCoderOpWithVariance(OpTest):
def test_check_output(self):
self.check_output()
def setUp(self):
self.op_type = "box_coder"
lod = [[1, 1, 1, 1, 1]]
prior_box = np.random.random((30, 4)).astype('float32')
prior_box_var = np.random.random((4)).astype('float32')
target_box = np.random.random((30, 81, 4)).astype('float32')
code_type = "DecodeCenterSize"
box_normalized = False
axis = 1
output_box = batch_box_coder(prior_box, prior_box_var, target_box,
lod[0], code_type, box_normalized, axis)
self.inputs = {
'PriorBox': prior_box,
'TargetBox': target_box,
}
self.attrs = {
'code_type': 'decode_center_size',
'box_normalized': False,
'variance': prior_box_var.astype(np.float).flatten(),
'axis': axis
}
self.outputs = {'OutputBox': output_box}
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册