提交 10dd3b37 编写于 作者: J jerrywgz

add axis for box coder op

上级 3f815e07
...@@ -315,7 +315,7 @@ paddle.fluid.layers.roi_perspective_transform ArgSpec(args=['input', 'rois', 'tr ...@@ -315,7 +315,7 @@ paddle.fluid.layers.roi_perspective_transform ArgSpec(args=['input', 'rois', 'tr
paddle.fluid.layers.generate_proposal_labels ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True)) paddle.fluid.layers.generate_proposal_labels ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True))
paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None)) paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None))
paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None)) paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'axis', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, 0, None))
paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'class_num', 'ignore_thresh', 'loss_weight_xy', 'loss_weight_wh', 'loss_weight_conf_target', 'loss_weight_conf_notarget', 'loss_weight_class', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None)) paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'class_num', 'ignore_thresh', 'loss_weight_xy', 'loss_weight_wh', 'loss_weight_conf_target', 'loss_weight_conf_notarget', 'loss_weight_class', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None))
paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None))
......
...@@ -32,31 +32,53 @@ class BoxCoderOp : public framework::OperatorWithKernel { ...@@ -32,31 +32,53 @@ class BoxCoderOp : public framework::OperatorWithKernel {
if (ctx->IsRuntime()) { if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(prior_box_dims.size(), 2, PADDLE_ENFORCE_EQ(prior_box_dims.size(), 2,
"The rank of Input of PriorBoxVar must be 2"); "The rank of Input of PriorBox must be 2");
PADDLE_ENFORCE_EQ(prior_box_dims[1], 4, PADDLE_ENFORCE_EQ(prior_box_dims[1], 4,
"The shape of PriorBox is [N, 4]"); "The shape of PriorBox is [N, 4]");
if (ctx->HasInput("PriorBoxVar")) { if (ctx->HasInput("PriorBoxVar")) {
auto prior_box_var_dims = ctx->GetInputDim("PriorBoxVar"); auto prior_box_var_dims = ctx->GetInputDim("PriorBoxVar");
PADDLE_ENFORCE_EQ(prior_box_dims, prior_box_var_dims); PADDLE_ENFORCE(
prior_box_var_dims.size() == 1 || prior_box_var_dims.size() == 2,
"Input(PriorBoxVar) of BoxCoderOp should be 1 or 2.");
// PriorBoxVar is accepted in two layouts (its rank was checked to be 1 or 2
// just above): a rank-1 tensor holding 4 variance values, or a rank-2 tensor
// whose shape matches PriorBox exactly.
if (prior_box_var_dims.size() == 1) {
  // The enforced value is 4, so the message must say 4. Adjacent string
  // literals are concatenated verbatim, hence the explicit trailing space.
  PADDLE_ENFORCE_EQ(
      prior_box_var_dims[0], 4,
      "The 1st dimension of Input(PriorBoxVar) should be 4 "
      "when the rank is 1.");
} else {
  PADDLE_ENFORCE_EQ(
      prior_box_dims, prior_box_var_dims,
      "The dimension of Input(PriorBoxVar) should be equal to "
      "the dimension of Input(PriorBox) when the rank is 2.");
}
} }
auto code_type = auto code_type =
GetBoxCodeType(ctx->Attrs().Get<std::string>("code_type")); GetBoxCodeType(ctx->Attrs().Get<std::string>("code_type"));
int axis = ctx->Attrs().Get<int>("axis");
if (code_type == BoxCodeType::kEncodeCenterSize) { if (code_type == BoxCodeType::kEncodeCenterSize) {
PADDLE_ENFORCE_EQ(target_box_dims.size(), 2, PADDLE_ENFORCE_EQ(target_box_dims.size(), 2,
"The rank of Input of TargetBox must be 2"); "The rank of Input of TargetBox must be 2");
PADDLE_ENFORCE_EQ(target_box_dims[1], 4, PADDLE_ENFORCE_EQ(target_box_dims[1], 4,
"The shape of TargetBox is [M, 4]"); "The shape of TargetBox is [M, 4]");
ctx->SetOutputDim(
"OutputBox",
framework::make_ddim({target_box_dims[0], prior_box_dims[0], 4}));
} else if (code_type == BoxCodeType::kDecodeCenterSize) { } else if (code_type == BoxCodeType::kDecodeCenterSize) {
PADDLE_ENFORCE_EQ(target_box_dims.size(), 3, PADDLE_ENFORCE_EQ(target_box_dims.size(), 3,
"The rank of Input of TargetBox must be 3"); "The rank of Input of TargetBox must be 3");
PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]); if (axis == 0) {
PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]);
} else if (axis == 1) {
PADDLE_ENFORCE_EQ(target_box_dims[0], prior_box_dims[0]);
} else {
PADDLE_THROW("axis must be 0 or 1.");
}
PADDLE_ENFORCE_EQ(target_box_dims[2], prior_box_dims[1]); PADDLE_ENFORCE_EQ(target_box_dims[2], prior_box_dims[1]);
ctx->ShareDim("TargetBox", /*->*/ "OutputBox");
} }
} }
ctx->SetOutputDim(
"OutputBox",
framework::make_ddim({target_box_dims[0], prior_box_dims[0], 4}));
ctx->ShareLoD("TargetBox", /*->*/ "OutputBox"); ctx->ShareLoD("TargetBox", /*->*/ "OutputBox");
} }
}; };
...@@ -100,6 +122,12 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -100,6 +122,12 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default true) " "(bool, default true) "
"whether treat the priorbox as a noramlized box") "whether treat the priorbox as a noramlized box")
.SetDefault(true); .SetDefault(true);
// Registers the "axis" attribute, restricted to 0 or 1 with a default of 0.
// The description is built from adjacent string literals, which are joined
// without any separator, so every piece ends with an explicit space.
AddAttr<int>("axis",
             "(int, default 0) "
             "which axis to broadcast for box decode, it is only valid "
             "when code type is decode_center_size")
    .SetDefault(0)
    .InEnum({0, 1});
AddOutput("OutputBox", AddOutput("OutputBox",
"(LoDTensor or Tensor) " "(LoDTensor or Tensor) "
"When code_type is 'encode_center_size', the output tensor of " "When code_type is 'encode_center_size', the output tensor of "
......
...@@ -20,7 +20,8 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data, ...@@ -20,7 +20,8 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data,
const T* prior_box_var_data, const T* prior_box_var_data,
const T* target_box_data, const int row, const T* target_box_data, const int row,
const int col, const int len, const int col, const int len,
const bool normalized, T* output) { const bool normalized,
const T prior_box_var_size, T* output) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x; const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < row * col) { if (idx < row * col) {
const int row_idx = idx / col; const int row_idx = idx / col;
...@@ -30,11 +31,9 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data, ...@@ -30,11 +31,9 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data,
T prior_box_height = prior_box_data[col_idx * len + 3] - T prior_box_height = prior_box_data[col_idx * len + 3] -
prior_box_data[col_idx * len + 1] + prior_box_data[col_idx * len + 1] +
(normalized == false); (normalized == false);
T prior_box_center_x = T prior_box_center_x = prior_box_data[col_idx * len] + prior_box_width / 2;
(prior_box_data[col_idx * len + 2] + prior_box_data[col_idx * len]) / 2; T prior_box_center_y =
T prior_box_center_y = (prior_box_data[col_idx * len + 3] + prior_box_data[col_idx * len + 1] + prior_box_height / 2;
prior_box_data[col_idx * len + 1]) /
2;
T target_box_center_x = T target_box_center_x =
(target_box_data[row_idx * len + 2] + target_box_data[row_idx * len]) / (target_box_data[row_idx * len + 2] + target_box_data[row_idx * len]) /
...@@ -55,10 +54,14 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data, ...@@ -55,10 +54,14 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data,
output[idx * len + 2] = log(fabs(target_box_width / prior_box_width)); output[idx * len + 2] = log(fabs(target_box_width / prior_box_width));
output[idx * len + 3] = log(fabs(target_box_height / prior_box_height)); output[idx * len + 3] = log(fabs(target_box_height / prior_box_height));
if (prior_box_var_data) { if (prior_box_var_data) {
output[idx * len] /= prior_box_var_data[col_idx * len]; int prior_var_offset = 0;
output[idx * len + 1] /= prior_box_var_data[col_idx * len + 1]; if (prior_box_var_size == 2) {
output[idx * len + 2] /= prior_box_var_data[col_idx * len + 2]; prior_var_offset = col_idx * len;
output[idx * len + 3] /= prior_box_var_data[col_idx * len + 3]; }
output[idx * len] /= prior_box_var_data[prior_var_offset];
output[idx * len + 1] /= prior_box_var_data[prior_var_offset + 1];
output[idx * len + 2] /= prior_box_var_data[prior_var_offset + 2];
output[idx * len + 3] /= prior_box_var_data[prior_var_offset + 3];
} }
} }
} }
...@@ -68,33 +71,48 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data, ...@@ -68,33 +71,48 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data,
const T* prior_box_var_data, const T* prior_box_var_data,
const T* target_box_data, const int row, const T* target_box_data, const int row,
const int col, const int len, const int col, const int len,
const bool normalized, T* output) { const bool normalized,
const T prior_box_var_size,
const int axis, T* output) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x; const int idx = threadIdx.x + blockIdx.x * blockDim.x;
int prior_box_offset = 0;
if (idx < row * col) { if (idx < row * col) {
const int col_idx = idx % col; const int col_idx = idx % col;
T prior_box_width = prior_box_data[col_idx * len + 2] - const int row_idx = idx / col;
prior_box_data[col_idx * len] + (normalized == false); if (axis == 0)
T prior_box_height = prior_box_data[col_idx * len + 3] - prior_box_offset = col_idx * len;
prior_box_data[col_idx * len + 1] + else if (axis == 1)
prior_box_offset = row_idx * len;
T prior_box_width = prior_box_data[prior_box_offset + 2] -
prior_box_data[prior_box_offset] +
(normalized == false);
T prior_box_height = prior_box_data[prior_box_offset + 3] -
prior_box_data[prior_box_offset + 1] +
(normalized == false); (normalized == false);
T prior_box_center_x = T prior_box_center_x =
(prior_box_data[col_idx * len + 2] + prior_box_data[col_idx * len]) / 2; prior_box_data[prior_box_offset] + prior_box_width / 2;
T prior_box_center_y = (prior_box_data[col_idx * len + 3] + T prior_box_center_y =
prior_box_data[col_idx * len + 1]) / prior_box_data[prior_box_offset + 1] + prior_box_height / 2;
2;
T target_box_width, target_box_height; T target_box_width, target_box_height;
T target_box_center_x, target_box_center_y; T target_box_center_x, target_box_center_y;
if (prior_box_var_data) { if (prior_box_var_data) {
target_box_width = exp(prior_box_var_data[col_idx * len + 2] * int prior_var_offset = 0;
if (prior_box_var_size == 2) {
if (axis == 0)
prior_var_offset = col_idx * len;
else if (axis == 1)
prior_var_offset = row_idx * len;
}
target_box_width = exp(prior_box_var_data[prior_var_offset + 2] *
target_box_data[idx * len + 2]) * target_box_data[idx * len + 2]) *
prior_box_width; prior_box_width;
target_box_height = exp(prior_box_var_data[col_idx * len + 3] * target_box_height = exp(prior_box_var_data[prior_var_offset + 3] *
target_box_data[idx * len + 3]) * target_box_data[idx * len + 3]) *
prior_box_height; prior_box_height;
target_box_center_x = prior_box_var_data[col_idx * len] * target_box_center_x = prior_box_var_data[prior_var_offset] *
target_box_data[idx * len] * prior_box_width + target_box_data[idx * len] * prior_box_width +
prior_box_center_x; prior_box_center_x;
target_box_center_y = prior_box_var_data[col_idx * len + 1] * target_box_center_y = prior_box_var_data[prior_var_offset + 1] *
target_box_data[idx * len + 1] * target_box_data[idx * len + 1] *
prior_box_height + prior_box_height +
prior_box_center_y; prior_box_center_y;
...@@ -131,14 +149,25 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> { ...@@ -131,14 +149,25 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
const T* prior_box_data = prior_box->data<T>(); const T* prior_box_data = prior_box->data<T>();
const T* target_box_data = target_box->data<T>(); const T* target_box_data = target_box->data<T>();
const T* prior_box_var_data = nullptr; const T* prior_box_var_data = nullptr;
if (prior_box_var) prior_box_var_data = prior_box_var->data<T>(); auto prior_box_var_size = 0;
if (prior_box_var) {
prior_box_var_data = prior_box_var->data<T>();
prior_box_var_size = prior_box_var->dims().size();
}
if (target_box->lod().size()) { if (target_box->lod().size()) {
PADDLE_ENFORCE_EQ(target_box->lod().size(), 1, PADDLE_ENFORCE_EQ(target_box->lod().size(), 1,
"Only support 1 level of LoD."); "Only support 1 level of LoD.");
} }
auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
bool normalized = context.Attr<bool>("box_normalized");
int axis = context.Attr<int>("axis");
auto row = target_box->dims()[0]; auto row = target_box->dims()[0];
auto col = prior_box->dims()[0]; auto col = prior_box->dims()[0];
if (code_type == BoxCodeType::kDecodeCenterSize) {
col = target_box->dims()[1];
}
auto len = prior_box->dims()[1]; auto len = prior_box->dims()[1];
int block = 512; int block = 512;
int grid = (row * col + block - 1) / block; int grid = (row * col + block - 1) / block;
...@@ -147,16 +176,14 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> { ...@@ -147,16 +176,14 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
output_box->mutable_data<T>({row, col, len}, context.GetPlace()); output_box->mutable_data<T>({row, col, len}, context.GetPlace());
T* output = output_box->data<T>(); T* output = output_box->data<T>();
auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
bool normalized = context.Attr<bool>("box_normalized");
if (code_type == BoxCodeType::kEncodeCenterSize) { if (code_type == BoxCodeType::kEncodeCenterSize) {
EncodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>( EncodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
prior_box_data, prior_box_var_data, target_box_data, row, col, len, prior_box_data, prior_box_var_data, target_box_data, row, col, len,
normalized, output); normalized, prior_box_var_size, output);
} else if (code_type == BoxCodeType::kDecodeCenterSize) { } else if (code_type == BoxCodeType::kDecodeCenterSize) {
DecodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>( DecodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
prior_box_data, prior_box_var_data, target_box_data, row, col, len, prior_box_data, prior_box_var_data, target_box_data, row, col, len,
normalized, output); normalized, prior_box_var_size, axis, output);
} }
} }
}; };
......
...@@ -53,10 +53,9 @@ class BoxCoderKernel : public framework::OpKernel<T> { ...@@ -53,10 +53,9 @@ class BoxCoderKernel : public framework::OpKernel<T> {
T prior_box_height = prior_box_data[j * len + 3] - T prior_box_height = prior_box_data[j * len + 3] -
prior_box_data[j * len + 1] + prior_box_data[j * len + 1] +
(normalized == false); (normalized == false);
T prior_box_center_x = T prior_box_center_x = prior_box_data[j * len] + prior_box_width / 2;
(prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2;
T prior_box_center_y = T prior_box_center_y =
(prior_box_data[j * len + 3] + prior_box_data[j * len + 1]) / 2; prior_box_data[j * len + 1] + prior_box_height / 2;
T target_box_center_x = T target_box_center_x =
(target_box_data[i * len + 2] + target_box_data[i * len]) / 2; (target_box_data[i * len + 2] + target_box_data[i * len]) / 2;
...@@ -78,10 +77,14 @@ class BoxCoderKernel : public framework::OpKernel<T> { ...@@ -78,10 +77,14 @@ class BoxCoderKernel : public framework::OpKernel<T> {
output[offset + 3] = output[offset + 3] =
std::log(std::fabs(target_box_height / prior_box_height)); std::log(std::fabs(target_box_height / prior_box_height));
if (prior_box_var) { if (prior_box_var) {
output[offset] /= prior_box_var_data[j * len]; int prior_var_offset = 0;
output[offset + 1] /= prior_box_var_data[j * len + 1]; if (prior_box_var->dims().size() == 2) {
output[offset + 2] /= prior_box_var_data[j * len + 2]; prior_var_offset = j * len;
output[offset + 3] /= prior_box_var_data[j * len + 3]; }
output[offset] /= prior_box_var_data[prior_var_offset];
output[offset + 1] /= prior_box_var_data[prior_var_offset + 1];
output[offset + 2] /= prior_box_var_data[prior_var_offset + 2];
output[offset + 3] /= prior_box_var_data[prior_var_offset + 3];
} }
} }
} }
...@@ -89,48 +92,63 @@ class BoxCoderKernel : public framework::OpKernel<T> { ...@@ -89,48 +92,63 @@ class BoxCoderKernel : public framework::OpKernel<T> {
void DecodeCenterSize(const framework::Tensor* target_box, void DecodeCenterSize(const framework::Tensor* target_box,
const framework::Tensor* prior_box, const framework::Tensor* prior_box,
const framework::Tensor* prior_box_var, const framework::Tensor* prior_box_var,
const bool normalized, T* output) const { const bool normalized, const int axis,
T* output) const {
int64_t row = target_box->dims()[0]; int64_t row = target_box->dims()[0];
int64_t col = prior_box->dims()[0]; int64_t col = target_box->dims()[1];
int64_t len = prior_box->dims()[1]; int64_t len = target_box->dims()[2];
auto* target_box_data = target_box->data<T>(); auto* target_box_data = target_box->data<T>();
auto* prior_box_data = prior_box->data<T>(); auto* prior_box_data = prior_box->data<T>();
const T* prior_box_var_data = nullptr; const T* prior_box_var_data = nullptr;
if (prior_box_var) prior_box_var_data = prior_box_var->data<T>(); if (prior_box_var) prior_box_var_data = prior_box_var->data<T>();
int prior_box_offset = 0;
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
#endif #endif
for (int64_t i = 0; i < row; ++i) { for (int64_t i = 0; i < row; ++i) {
for (int64_t j = 0; j < col; ++j) { for (int64_t j = 0; j < col; ++j) {
size_t offset = i * col * len + j * len; size_t offset = i * col * len + j * len;
T prior_box_width = prior_box_data[j * len + 2] - if (axis == 0) {
prior_box_data[j * len] + (normalized == false); prior_box_offset = j * len;
T prior_box_height = prior_box_data[j * len + 3] - } else if (axis == 1) {
prior_box_data[j * len + 1] + prior_box_offset = i * len;
}
T prior_box_width = prior_box_data[prior_box_offset + 2] -
prior_box_data[prior_box_offset] +
(normalized == false);
T prior_box_height = prior_box_data[prior_box_offset + 3] -
prior_box_data[prior_box_offset + 1] +
(normalized == false); (normalized == false);
T prior_box_center_x = T prior_box_center_x =
(prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2; prior_box_data[prior_box_offset] + prior_box_width / 2;
T prior_box_center_y = T prior_box_center_y =
(prior_box_data[j * len + 3] + prior_box_data[j * len + 1]) / 2; prior_box_data[prior_box_offset + 1] + prior_box_height / 2;
T target_box_center_x = 0, target_box_center_y = 0; T target_box_center_x = 0, target_box_center_y = 0;
T target_box_width = 0, target_box_height = 0; T target_box_width = 0, target_box_height = 0;
if (prior_box_var) { if (prior_box_var) {
target_box_center_x = prior_box_var_data[j * len] * int prior_var_offset = 0;
if (prior_box_var->dims().size() == 2) {
if (axis == 0)
prior_var_offset = j * len;
else if (axis == 1)
prior_var_offset = i * len;
}
target_box_center_x = prior_box_var_data[prior_var_offset] *
target_box_data[offset] * prior_box_width + target_box_data[offset] * prior_box_width +
prior_box_center_x; prior_box_center_x;
target_box_center_y = prior_box_var_data[j * len + 1] * target_box_center_y = prior_box_var_data[prior_var_offset + 1] *
target_box_data[offset + 1] * target_box_data[offset + 1] *
prior_box_height + prior_box_height +
prior_box_center_y; prior_box_center_y;
target_box_width = std::exp(prior_box_var_data[j * len + 2] * target_box_width = std::exp(prior_box_var_data[prior_var_offset + 2] *
target_box_data[offset + 2]) * target_box_data[offset + 2]) *
prior_box_width; prior_box_width;
target_box_height = std::exp(prior_box_var_data[j * len + 3] * target_box_height =
target_box_data[offset + 3]) * std::exp(prior_box_var_data[prior_var_offset + 3] *
prior_box_height; target_box_data[offset + 3]) *
prior_box_height;
} else { } else {
target_box_center_x = target_box_center_x =
target_box_data[offset] * prior_box_width + prior_box_center_x; target_box_data[offset] * prior_box_width + prior_box_center_x;
...@@ -157,25 +175,29 @@ class BoxCoderKernel : public framework::OpKernel<T> { ...@@ -157,25 +175,29 @@ class BoxCoderKernel : public framework::OpKernel<T> {
auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar"); auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
auto* target_box = context.Input<framework::LoDTensor>("TargetBox"); auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
auto* output_box = context.Output<framework::Tensor>("OutputBox"); auto* output_box = context.Output<framework::Tensor>("OutputBox");
const int axis = context.Attr<int>("axis");
if (target_box->lod().size()) { if (target_box->lod().size()) {
PADDLE_ENFORCE_EQ(target_box->lod().size(), 1UL, PADDLE_ENFORCE_EQ(target_box->lod().size(), 1UL,
"Only support 1 level of LoD."); "Only support 1 level of LoD.");
} }
auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
bool normalized = context.Attr<bool>("box_normalized");
auto row = target_box->dims()[0]; auto row = target_box->dims()[0];
auto col = prior_box->dims()[0]; auto col = prior_box->dims()[0];
if (code_type == BoxCodeType::kDecodeCenterSize) {
col = target_box->dims()[1];
}
auto len = prior_box->dims()[1]; auto len = prior_box->dims()[1];
output_box->mutable_data<T>({row, col, len}, context.GetPlace()); output_box->mutable_data<T>({row, col, len}, context.GetPlace());
auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
bool normalized = context.Attr<bool>("box_normalized");
T* output = output_box->data<T>(); T* output = output_box->data<T>();
if (code_type == BoxCodeType::kEncodeCenterSize) { if (code_type == BoxCodeType::kEncodeCenterSize) {
EncodeCenterSize(target_box, prior_box, prior_box_var, normalized, EncodeCenterSize(target_box, prior_box, prior_box_var, normalized,
output); output);
} else if (code_type == BoxCodeType::kDecodeCenterSize) { } else if (code_type == BoxCodeType::kDecodeCenterSize) {
DecodeCenterSize(target_box, prior_box, prior_box_var, normalized, DecodeCenterSize(target_box, prior_box, prior_box_var, normalized, axis,
output); output);
} }
} }
......
...@@ -342,6 +342,7 @@ def box_coder(prior_box, ...@@ -342,6 +342,7 @@ def box_coder(prior_box,
target_box, target_box,
code_type="encode_center_size", code_type="encode_center_size",
box_normalized=True, box_normalized=True,
axis=0,
name=None): name=None):
""" """
${comment} ${comment}
...@@ -352,6 +353,7 @@ def box_coder(prior_box, ...@@ -352,6 +353,7 @@ def box_coder(prior_box,
target_box(${target_box_type}): ${target_box_comment} target_box(${target_box_type}): ${target_box_comment}
code_type(${code_type_type}): ${code_type_comment} code_type(${code_type_type}): ${code_type_comment}
box_normalized(${box_normalized_type}): ${box_normalized_comment} box_normalized(${box_normalized_type}): ${box_normalized_comment}
axis(${axis_type}): ${axis_comment}
Returns: Returns:
output_box(${output_box_type}): ${output_box_comment} output_box(${output_box_type}): ${output_box_comment}
...@@ -372,8 +374,11 @@ def box_coder(prior_box, ...@@ -372,8 +374,11 @@ def box_coder(prior_box,
"PriorBoxVar": prior_box_var, "PriorBoxVar": prior_box_var,
"TargetBox": target_box "TargetBox": target_box
}, },
attrs={"code_type": code_type, attrs={
"box_normalized": box_normalized}, "code_type": code_type,
"box_normalized": box_normalized,
"axis": axis
},
outputs={"OutputBox": output_box}) outputs={"OutputBox": output_box})
return output_box return output_box
......
...@@ -21,22 +21,32 @@ import math ...@@ -21,22 +21,32 @@ import math
from op_test import OpTest from op_test import OpTest
def box_coder(target_box, prior_box, prior_box_var, output_box, code_type, def box_coder(target_box,
box_normalized): prior_box,
prior_box_x = ( prior_box_var,
(prior_box[:, 2] + prior_box[:, 0]) / 2).reshape(1, prior_box.shape[0]) output_box,
prior_box_y = ( code_type,
(prior_box[:, 3] + prior_box[:, 1]) / 2).reshape(1, prior_box.shape[0]) box_normalized,
prior_box_width = ( axis=0):
(prior_box[:, 2] - prior_box[:, 0])).reshape(1, prior_box.shape[0]) prior_box_width = prior_box[:, 2] - prior_box[:, 0] + \
prior_box_height = ( (box_normalized==False)
(prior_box[:, 3] - prior_box[:, 1])).reshape(1, prior_box.shape[0]) prior_box_height = prior_box[:, 3] - prior_box[:, 1] + \
prior_box_var = prior_box_var.reshape(1, prior_box_var.shape[0], (box_normalized==False)
prior_box_var.shape[1]) prior_box_x = prior_box_width * 0.5 + prior_box[:, 0]
if not box_normalized: prior_box_y = prior_box_height * 0.5 + prior_box[:, 1]
prior_box_height = prior_box_height + 1 if axis == 0:
prior_box_width = prior_box_width + 1 prior_box_width = prior_box_width.reshape(1, prior_box.shape[0])
prior_box_height = prior_box_height.reshape(1, prior_box.shape[0])
prior_box_x = prior_box_x.reshape(1, prior_box.shape[0])
prior_box_y = prior_box_y.reshape(1, prior_box.shape[0])
else:
prior_box_width = prior_box_width.reshape(prior_box.shape[0], 1)
prior_box_height = prior_box_height.reshape(prior_box.shape[0], 1)
prior_box_x = prior_box_x.reshape(prior_box.shape[0], 1)
prior_box_y = prior_box_y.reshape(prior_box.shape[0], 1)
if prior_box_var.ndim == 2:
prior_box_var = prior_box_var.reshape(1, prior_box_var.shape[0],
prior_box_var.shape[1])
if (code_type == "EncodeCenterSize"): if (code_type == "EncodeCenterSize"):
target_box_x = ((target_box[:, 2] + target_box[:, 0]) / 2).reshape( target_box_x = ((target_box[:, 2] + target_box[:, 0]) / 2).reshape(
target_box.shape[0], 1) target_box.shape[0], 1)
...@@ -49,26 +59,52 @@ def box_coder(target_box, prior_box, prior_box_var, output_box, code_type, ...@@ -49,26 +59,52 @@ def box_coder(target_box, prior_box, prior_box_var, output_box, code_type,
if not box_normalized: if not box_normalized:
target_box_height = target_box_height + 1 target_box_height = target_box_height + 1
target_box_width = target_box_width + 1 target_box_width = target_box_width + 1
if prior_box_var.ndim == 1:
output_box[:,:,0] = (target_box_x - prior_box_x) / prior_box_width / \ output_box[:,:,0] = (target_box_x - prior_box_x) / \
prior_box_var[:,:,0] prior_box_width / \
output_box[:,:,1] = (target_box_y - prior_box_y) / prior_box_height / \ prior_box_var[0]
prior_box_var[:,:,1] output_box[:,:,1] = (target_box_y - prior_box_y) / \
output_box[:,:,2] = np.log(np.fabs(target_box_width / prior_box_width)) / \ prior_box_height / \
prior_box_var[:,:,2] prior_box_var[1]
output_box[:,:,3] = np.log(np.fabs(target_box_height / prior_box_height)) / \ output_box[:,:,2] = np.log(np.fabs(target_box_width / \
prior_box_var[:,:,3] prior_box_width)) / \
prior_box_var[2]
output_box[:,:,3] = np.log(np.fabs(target_box_height / \
prior_box_height)) / \
prior_box_var[3]
else:
output_box[:,:,0] = (target_box_x - prior_box_x) / \
prior_box_width / \
prior_box_var[:,:,0]
output_box[:,:,1] = (target_box_y - prior_box_y) / \
prior_box_height / \
prior_box_var[:,:,1]
output_box[:,:,2] = np.log(np.fabs(target_box_width / \
prior_box_width)) / \
prior_box_var[:,:,2]
output_box[:,:,3] = np.log(np.fabs(target_box_height / \
prior_box_height)) / \
prior_box_var[:,:,3]
elif (code_type == "DecodeCenterSize"): elif (code_type == "DecodeCenterSize"):
target_box_x = prior_box_var[:,:,0] * target_box[:,:,0] * \ if prior_box_var.ndim == 1:
prior_box_width + prior_box_x target_box_x = prior_box_var[0] * target_box[:,:,0] * \
target_box_y = prior_box_var[:,:,1] * target_box[:,:,1] * \ prior_box_width + prior_box_x
prior_box_height + prior_box_y target_box_y = prior_box_var[1] * target_box[:,:,1] * \
target_box_width = np.exp(prior_box_var[:,:,2] * target_box[:,:,2]) * \ prior_box_height + prior_box_y
prior_box_width target_box_width = np.exp(prior_box_var[2] * target_box[:,:,2]) * \
target_box_height = np.exp(prior_box_var[:,:,3] * target_box[:,:,3]) * \ prior_box_width
prior_box_height target_box_height = np.exp(prior_box_var[3] * target_box[:,:,3]) * \
prior_box_height
else:
target_box_x = prior_box_var[:,:,0] * target_box[:,:,0] * \
prior_box_width + prior_box_x
target_box_y = prior_box_var[:,:,1] * target_box[:,:,1] * \
prior_box_height + prior_box_y
target_box_width = np.exp(prior_box_var[:,:,2] * \
target_box[:,:,2]) * prior_box_width
target_box_height = np.exp(prior_box_var[:,:,3] * \
target_box[:,:,3]) * prior_box_height
output_box[:, :, 0] = target_box_x - target_box_width / 2 output_box[:, :, 0] = target_box_x - target_box_width / 2
output_box[:, :, 1] = target_box_y - target_box_height / 2 output_box[:, :, 1] = target_box_y - target_box_height / 2
output_box[:, :, 2] = target_box_x + target_box_width / 2 output_box[:, :, 2] = target_box_x + target_box_width / 2
...@@ -78,10 +114,17 @@ def box_coder(target_box, prior_box, prior_box_var, output_box, code_type, ...@@ -78,10 +114,17 @@ def box_coder(target_box, prior_box, prior_box_var, output_box, code_type,
output_box[:, :, 3] = output_box[:, :, 3] - 1 output_box[:, :, 3] = output_box[:, :, 3] - 1
def batch_box_coder(prior_box, prior_box_var, target_box, lod, code_type, def batch_box_coder(prior_box,
box_normalized): prior_box_var,
target_box,
lod,
code_type,
box_normalized,
axis=0):
n = target_box.shape[0] n = target_box.shape[0]
m = prior_box.shape[0] m = prior_box.shape[0]
if code_type == "DecodeCenterSize":
m = target_box.shape[1]
output_box = np.zeros((n, m, 4), dtype=np.float32) output_box = np.zeros((n, m, 4), dtype=np.float32)
cur_offset = 0 cur_offset = 0
for i in range(len(lod)): for i in range(len(lod)):
...@@ -91,10 +134,8 @@ def batch_box_coder(prior_box, prior_box_var, target_box, lod, code_type, ...@@ -91,10 +134,8 @@ def batch_box_coder(prior_box, prior_box_var, target_box, lod, code_type,
output_box[cur_offset:(cur_offset + lod[i]), :, :], output_box[cur_offset:(cur_offset + lod[i]), :, :],
code_type, box_normalized) code_type, box_normalized)
elif (code_type == "DecodeCenterSize"): elif (code_type == "DecodeCenterSize"):
box_coder(target_box[cur_offset:(cur_offset + lod[i]), :, :], box_coder(target_box, prior_box, prior_box_var, output_box,
prior_box, prior_box_var, code_type, box_normalized, axis)
output_box[cur_offset:(cur_offset + lod[i]), :, :],
code_type, box_normalized)
cur_offset += lod[i] cur_offset += lod[i]
return output_box return output_box
...@@ -111,6 +152,32 @@ class TestBoxCoderOp(OpTest): ...@@ -111,6 +152,32 @@ class TestBoxCoderOp(OpTest):
target_box = np.random.random((5, 10, 4)).astype('float32') target_box = np.random.random((5, 10, 4)).astype('float32')
code_type = "DecodeCenterSize" code_type = "DecodeCenterSize"
box_normalized = False box_normalized = False
output_box = batch_box_coder(prior_box, prior_box_var, target_box,
lod[0], code_type, box_normalized)
self.inputs = {
'PriorBox': prior_box,
'PriorBoxVar': prior_box_var,
'TargetBox': target_box,
}
self.attrs = {
'code_type': 'decode_center_size',
'box_normalized': False
}
self.outputs = {'OutputBox': output_box}
class TestBoxCoderOpWithOneRankVar(OpTest):
def test_check_output(self):
self.check_output()
def setUp(self):
self.op_type = "box_coder"
lod = [[1, 1, 1, 1, 1]]
prior_box = np.random.random((6, 4)).astype('float32')
prior_box_var = np.random.random((4)).astype('float32')
target_box = np.random.random((3, 6, 4)).astype('float32')
code_type = "DecodeCenterSize"
box_normalized = False
output_box = batch_box_coder(prior_box, prior_box_var, target_box, output_box = batch_box_coder(prior_box, prior_box_var, target_box,
lod[0], code_type, box_normalized) lod[0], code_type, box_normalized)
...@@ -176,5 +243,34 @@ class TestBoxCoderOpWithLoD(OpTest): ...@@ -176,5 +243,34 @@ class TestBoxCoderOpWithLoD(OpTest):
self.outputs = {'OutputBox': output_box} self.outputs = {'OutputBox': output_box}
class TestBoxCoderOpWithAxis(OpTest):
    # Decode test with axis=1: PriorBox has 5 rows, equal to the first
    # dimension of the (5, 6, 4) TargetBox, and PriorBoxVar is a rank-1
    # tensor of 4 values.

    def test_check_output(self):
        self.check_output()

    def setUp(self):
        self.op_type = "box_coder"

        # Five sequences of length one; only the first (and only) LoD level
        # is handed to the reference implementation.
        seq_lengths = [[1, 1, 1, 1, 1]]

        # Keep the draw order (priors, variance, targets) so the random
        # stream is consumed exactly as before.
        priors = np.random.random((5, 4)).astype('float32')
        variance = np.random.random((4)).astype('float32')
        targets = np.random.random((5, 6, 4)).astype('float32')

        decode_axis = 1
        normalized = False

        # Reference result computed by the NumPy helper.
        expected = batch_box_coder(priors, variance, targets, seq_lengths[0],
                                   "DecodeCenterSize", normalized,
                                   decode_axis)

        self.inputs = dict([
            ('PriorBox', priors),
            ('PriorBoxVar', variance),
            ('TargetBox', targets),
        ])
        self.attrs = dict([
            ('code_type', 'decode_center_size'),
            ('box_normalized', False),
            ('axis', decode_axis),
        ])
        self.outputs = dict([('OutputBox', expected)])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册