提交 a2e83d1d 编写于 作者: J jerrywgz 提交者: ceci3

add box_coder_and_assign, test=develop

上级 69859718
......@@ -33,6 +33,7 @@ detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc)
detection_library(generate_proposal_labels_op SRCS generate_proposal_labels_op.cc)
detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu)
detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc)
detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc box_decoder_and_assign_op.cu)
if(WITH_GPU)
detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub)
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/box_decoder_and_assign_op.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
class BoxDecoderAndAssignOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(
ctx->HasInput("PriorBox"),
"Input(PriorBox) of BoxDecoderAndAssignOp should not be null.");
PADDLE_ENFORCE(
ctx->HasInput("PriorBoxVar"),
"Input(PriorBoxVar) of BoxDecoderAndAssignOp should not be null.");
PADDLE_ENFORCE(
ctx->HasInput("TargetBox"),
"Input(TargetBox) of BoxDecoderAndAssignOp should not be null.");
PADDLE_ENFORCE(
ctx->HasInput("BoxScore"),
"Input(BoxScore) of BoxDecoderAndAssignOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("OutputBox"),
"Output(OutputBox) of BoxDecoderAndAssignOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("OutputAssignBox"),
"Output(OutputAssignBox) of BoxDecoderAndAssignOp should not be null.");
auto prior_box_dims = ctx->GetInputDim("PriorBox");
auto prior_box_var_dims = ctx->GetInputDim("PriorBoxVar");
auto target_box_dims = ctx->GetInputDim("TargetBox");
auto box_score_dims = ctx->GetInputDim("BoxScore");
PADDLE_ENFORCE_EQ(prior_box_dims.size(), 2,
"The rank of Input of PriorBox must be 2");
PADDLE_ENFORCE_EQ(prior_box_dims[1], 4, "The shape of PriorBox is [N, 4]");
PADDLE_ENFORCE_EQ(prior_box_var_dims.size(), 1,
"The rank of Input of PriorBoxVar must be 1");
PADDLE_ENFORCE_EQ(prior_box_var_dims[0], 4,
"The shape of PriorBoxVar is [4]");
PADDLE_ENFORCE_EQ(target_box_dims.size(), 2,
"The rank of Input of TargetBox must be 2");
PADDLE_ENFORCE_EQ(box_score_dims.size(), 2,
"The rank of Input of BoxScore must be 2");
PADDLE_ENFORCE_EQ(prior_box_dims[0], target_box_dims[0],
"The first dim of prior_box and target_box is roi nums "
"and should be same!");
PADDLE_ENFORCE_EQ(prior_box_dims[0], box_score_dims[0],
"The first dim of prior_box and box_score is roi nums "
"and should be same!");
PADDLE_ENFORCE_EQ(target_box_dims[1], box_score_dims[1] * prior_box_dims[1],
"The shape of target_box is [N, classnum * 4], The shape "
"of box_score is [N, classnum], The shape of prior_box "
"is [N, 4]");
ctx->SetOutputDim("OutputBox", framework::make_ddim({target_box_dims[0],
target_box_dims[1]}));
ctx->ShareLoD("TargetBox", /*->*/ "OutputBox");
ctx->SetOutputDim(
"OutputAssignBox",
framework::make_ddim({prior_box_dims[0], prior_box_dims[1]}));
ctx->ShareLoD("PriorBox", /*->*/ "OutputAssignBox");
}
};
class BoxDecoderAndAssignOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"PriorBox",
"(Tensor, default Tensor<float>) "
"Box list PriorBox is a 2-D Tensor with shape [M, 4] holds M boxes, "
"each box is represented as [xmin, ymin, xmax, ymax], "
"[xmin, ymin] is the left top coordinate of the anchor box, "
"if the input is image feature map, they are close to the origin "
"of the coordinate system. [xmax, ymax] is the right bottom "
"coordinate of the anchor box.");
AddInput("PriorBoxVar",
"(Tensor, default Tensor<float>, optional) "
"PriorBoxVar is a 2-D Tensor with shape [M, 4] holds M group "
"of variance. PriorBoxVar will set all elements to 1 by "
"default.")
.AsDispensable();
AddInput(
"TargetBox",
"(LoDTensor or Tensor) This input can be a 2-D LoDTensor with shape "
"[N, classnum*4]. [N, classnum*4], each box is represented as "
"[xmin, ymin, xmax, ymax], [xmin, ymin] is the left top coordinate "
"of the box if the input is image feature map, they are close to "
"the origin of the coordinate system. [xmax, ymax] is the right "
"bottom coordinate of the box. This tensor can contain LoD "
"information to represent a batch of inputs. One instance of this "
"batch can contain different numbers of entities.");
AddInput(
"BoxScore",
"(LoDTensor or Tensor) This input can be a 2-D LoDTensor with shape "
"[N, classnum], each box is represented as [classnum] which is "
"the classification probabilities.");
AddAttr<float>("box_clip",
"(float, default 4.135, np.log(1000. / 16.)) "
"clip box to prevent overflowing")
.SetDefault(4.135f);
AddOutput("OutputBox",
"(LoDTensor or Tensor) "
"the output tensor of op with shape [N, classnum * 4] "
"representing the result of N target boxes decoded with "
"M Prior boxes and variances for each class.");
AddOutput("OutputAssignBox",
"(LoDTensor or Tensor) "
"the output tensor of op with shape [N, 4] "
"representing the result of N target boxes decoded with "
"M Prior boxes and variances with the best non-background class "
"by BoxScore.");
AddComment(R"DOC(
Bounding Box Coder.
Decode the target bounding box with the priorbox information.
The Decoding schema described below:
ox = (pw * pxv * tx * + px) - tw / 2
oy = (ph * pyv * ty * + py) - th / 2
ow = exp(pwv * tw) * pw + tw / 2
oh = exp(phv * th) * ph + th / 2
where `tx`, `ty`, `tw`, `th` denote the target box's center coordinates, width
and height respectively. Similarly, `px`, `py`, `pw`, `ph` denote the
priorbox's (anchor) center coordinates, width and height. `pxv`, `pyv`, `pwv`,
`phv` denote the variance of the priorbox and `ox`, `oy`, `ow`, `oh` denote the
encoded/decoded coordinates, width and height.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(box_decoder_and_assign, ops::BoxDecoderAndAssignOp,
ops::BoxDecoderAndAssignOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
box_decoder_and_assign,
ops::BoxDecoderAndAssignKernel<paddle::platform::CPUDeviceContext, float>,
ops::BoxDecoderAndAssignKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/detection/box_decoder_and_assign_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
namespace operators {
template <typename T>
__global__ void DecodeBoxKernel(const T* prior_box_data,
const T* prior_box_var_data,
const T* target_box_data, const int roi_num,
const int class_num, const T box_clip,
T* output_box_data) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < roi_num * class_num) {
int i = idx / class_num;
int j = idx % class_num;
T prior_box_width = prior_box_data[i * 4 + 2] - prior_box_data[i * 4] + 1;
T prior_box_height =
prior_box_data[i * 4 + 3] - prior_box_data[i * 4 + 1] + 1;
T prior_box_center_x = prior_box_data[i * 4] + prior_box_width / 2;
T prior_box_center_y = prior_box_data[i * 4 + 1] + prior_box_height / 2;
int offset = i * class_num * 4 + j * 4;
T dw = prior_box_var_data[2] * target_box_data[offset + 2];
T dh = prior_box_var_data[3] * target_box_data[offset + 3];
if (dw > box_clip) {
dw = box_clip;
}
if (dh > box_clip) {
dh = box_clip;
}
T target_box_center_x = 0, target_box_center_y = 0;
T target_box_width = 0, target_box_height = 0;
target_box_center_x =
prior_box_var_data[0] * target_box_data[offset] * prior_box_width +
prior_box_center_x;
target_box_center_y =
prior_box_var_data[1] * target_box_data[offset + 1] * prior_box_height +
prior_box_center_y;
target_box_width = expf(dw) * prior_box_width;
target_box_height = expf(dh) * prior_box_height;
output_box_data[offset] = target_box_center_x - target_box_width / 2;
output_box_data[offset + 1] = target_box_center_y - target_box_height / 2;
output_box_data[offset + 2] =
target_box_center_x + target_box_width / 2 - 1;
output_box_data[offset + 3] =
target_box_center_y + target_box_height / 2 - 1;
}
}
template <typename T>
__global__ void AssignBoxKernel(const T* prior_box_data,
const T* box_score_data, T* output_box_data,
const int roi_num, const int class_num,
T* output_assign_box_data) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < roi_num) {
int i = idx;
T max_score = -1;
int max_j = -1;
for (int j = 0; j < class_num; ++j) {
T score = box_score_data[i * class_num + j];
if (score > max_score && j > 0) {
max_score = score;
max_j = j;
}
}
if (max_j > 0) {
for (int pno = 0; pno < 4; pno++) {
output_assign_box_data[i * 4 + pno] =
output_box_data[i * class_num * 4 + max_j * 4 + pno];
}
} else {
for (int pno = 0; pno < 4; pno++) {
output_assign_box_data[i * 4 + pno] = prior_box_data[i * 4 + pno];
}
}
}
}
template <typename DeviceContext, typename T>
class BoxDecoderAndAssignCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
"This kernel only runs on GPU device.");
auto* prior_box = context.Input<framework::LoDTensor>("PriorBox");
auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
auto* box_score = context.Input<framework::LoDTensor>("BoxScore");
auto* output_box = context.Output<framework::Tensor>("OutputBox");
auto* output_assign_box =
context.Output<framework::Tensor>("OutputAssignBox");
auto roi_num = target_box->dims()[0];
auto class_num = box_score->dims()[1];
auto* target_box_data = target_box->data<T>();
auto* prior_box_data = prior_box->data<T>();
auto* prior_box_var_data = prior_box_var->data<T>();
auto* box_score_data = box_score->data<T>();
output_box->mutable_data<T>({roi_num, class_num * 4}, context.GetPlace());
output_assign_box->mutable_data<T>({roi_num, 4}, context.GetPlace());
T* output_box_data = output_box->data<T>();
T* output_assign_box_data = output_assign_box->data<T>();
int block = 512;
int grid = (roi_num * class_num + block - 1) / block;
auto& device_ctx = context.cuda_device_context();
const T box_clip = context.Attr<T>("box_clip");
DecodeBoxKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
prior_box_data, prior_box_var_data, target_box_data, roi_num, class_num,
box_clip, output_box_data);
context.device_context().Wait();
int assign_grid = (roi_num + block - 1) / block;
AssignBoxKernel<T><<<assign_grid, block, 0, device_ctx.stream()>>>(
prior_box_data, box_score_data, output_box_data, roi_num, class_num,
output_assign_box_data);
context.device_context().Wait();
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
box_decoder_and_assign,
ops::BoxDecoderAndAssignCUDAKernel<paddle::platform::CUDADeviceContext,
float>,
ops::BoxDecoderAndAssignCUDAKernel<paddle::platform::CUDADeviceContext,
double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class BoxDecoderAndAssignKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* prior_box = context.Input<framework::LoDTensor>("PriorBox");
auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
auto* box_score = context.Input<framework::LoDTensor>("BoxScore");
auto* output_box = context.Output<framework::Tensor>("OutputBox");
auto* output_assign_box =
context.Output<framework::Tensor>("OutputAssignBox");
int roi_num = target_box->dims()[0];
int class_num = box_score->dims()[1];
auto* target_box_data = target_box->data<T>();
auto* prior_box_data = prior_box->data<T>();
auto* prior_box_var_data = prior_box_var->data<T>();
auto* box_score_data = box_score->data<T>();
output_box->mutable_data<T>({roi_num, class_num * 4}, context.GetPlace());
output_assign_box->mutable_data<T>({roi_num, 4}, context.GetPlace());
T* output_box_data = output_box->data<T>();
T* output_assign_box_data = output_assign_box->data<T>();
const T bbox_clip = context.Attr<T>("box_clip");
for (int i = 0; i < roi_num; ++i) {
T prior_box_width = prior_box_data[i * 4 + 2] - prior_box_data[i * 4] + 1;
T prior_box_height =
prior_box_data[i * 4 + 3] - prior_box_data[i * 4 + 1] + 1;
T prior_box_center_x = prior_box_data[i * 4] + prior_box_width / 2;
T prior_box_center_y = prior_box_data[i * 4 + 1] + prior_box_height / 2;
for (int j = 0; j < class_num; ++j) {
int64_t offset = i * class_num * 4 + j * 4;
T dw = std::min(prior_box_var_data[2] * target_box_data[offset + 2],
bbox_clip);
T dh = std::min(prior_box_var_data[3] * target_box_data[offset + 3],
bbox_clip);
T target_box_center_x = 0, target_box_center_y = 0;
T target_box_width = 0, target_box_height = 0;
target_box_center_x =
prior_box_var_data[0] * target_box_data[offset] * prior_box_width +
prior_box_center_x;
target_box_center_y = prior_box_var_data[1] *
target_box_data[offset + 1] *
prior_box_height +
prior_box_center_y;
target_box_width = std::exp(dw) * prior_box_width;
target_box_height = std::exp(dh) * prior_box_height;
output_box_data[offset] = target_box_center_x - target_box_width / 2;
output_box_data[offset + 1] =
target_box_center_y - target_box_height / 2;
output_box_data[offset + 2] =
target_box_center_x + target_box_width / 2 - 1;
output_box_data[offset + 3] =
target_box_center_y + target_box_height / 2 - 1;
}
T max_score = -1;
int max_j = -1;
for (int j = 0; j < class_num; ++j) {
T score = box_score_data[i * class_num + j];
if (score > max_score && j > 0) {
max_score = score;
max_j = j;
}
}
if (max_j > 0) {
for (int pno = 0; pno < 4; pno++) {
output_assign_box_data[i * 4 + pno] =
output_box_data[i * class_num * 4 + max_j * 4 + pno];
}
} else {
for (int pno = 0; pno < 4; pno++) {
output_assign_box_data[i * 4 + pno] = prior_box_data[i * 4 + pno];
}
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -51,6 +51,7 @@ __all__ = [
'yolov3_loss',
'box_clip',
'multiclass_nms',
'box_decoder_and_assign',
]
......@@ -2221,3 +2222,53 @@ def multiclass_nms(bboxes,
output.stop_gradient = True
return output
@templatedoc()
def box_decoder_and_assign(prior_box, prior_box_var, target_box, box_score,
box_clip):
"""
${comment}
Args:
prior_box(${prior_box_type}): ${prior_box_comment}
prior_box_var(${prior_box_var_type}): ${prior_box_var_comment}
target_box(${target_box_type}): ${target_box_comment}
box_score(${box_score_type}): ${box_score_comment}
Returns:
output_box(${output_box_type}): ${output_box_comment}
output_assign_box(${output_assign_box_type}): ${output_assign_box_comment}
Examples:
.. code-block:: python
pb = fluid.layers.data(name='prior_box', shape=[20, 4],
dtype='float32')
pbv = fluid.layers.data(name='prior_box_var', shape=[1, 4],
dtype='float32')
loc = fluid.layers.data(name='target_box', shape=[20, 4*81],
dtype='float32')
scores = fluid.layers.data(name='scores', shape=[20, 81],
dtype='float32')
output_box, output_assign_box = fluid.layers.box_decoder_and_assign(pb, pbv, loc, scores, 4.135)
"""
helper = LayerHelper("box_decoder_and_assign", **locals())
output_box = helper.create_variable_for_type_inference(
dtype=prior_box.dtype)
output_assign_box = helper.create_variable_for_type_inference(
dtype=prior_box.dtype)
helper.append_op(
type="box_decoder_and_assign",
inputs={
"PriorBox": prior_box,
"PriorBoxVar": prior_box_var,
"TargetBox": target_box,
"BoxScore": box_score
},
attrs={"box_clip": box_clip},
outputs={
"OutputBox": output_box,
"OutputAssignBox": output_assign_box
})
return output_box, output_assign_box
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
import math
from op_test import OpTest
def box_decoder_and_assign(deltas, weights, boxes, box_score, box_clip):
boxes = boxes.astype(deltas.dtype, copy=False)
widths = boxes[:, 2] - boxes[:, 0] + 1.0
heights = boxes[:, 3] - boxes[:, 1] + 1.0
ctr_x = boxes[:, 0] + 0.5 * widths
ctr_y = boxes[:, 1] + 0.5 * heights
wx, wy, ww, wh = weights
dx = deltas[:, 0::4] * wx
dy = deltas[:, 1::4] * wy
dw = deltas[:, 2::4] * ww
dh = deltas[:, 3::4] * wh
# Prevent sending too large values into np.exp()
dw = np.minimum(dw, box_clip)
dh = np.minimum(dh, box_clip)
pred_ctr_x = dx * widths[:, np.newaxis] + ctr_x[:, np.newaxis]
pred_ctr_y = dy * heights[:, np.newaxis] + ctr_y[:, np.newaxis]
pred_w = np.exp(dw) * widths[:, np.newaxis]
pred_h = np.exp(dh) * heights[:, np.newaxis]
pred_boxes = np.zeros(deltas.shape, dtype=deltas.dtype)
# x1
pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w
# y1
pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h
# x2 (note: "- 1" is correct; don't be fooled by the asymmetry)
pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w - 1
# y2 (note: "- 1" is correct; don't be fooled by the asymmetry)
pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h - 1
output_assign_box = []
for ino in range(len(pred_boxes)):
rank = np.argsort(-box_score[ino])
maxidx = rank[0]
if maxidx == 0:
maxidx = rank[1]
beg_pos = maxidx * 4
end_pos = maxidx * 4 + 4
output_assign_box.append(pred_boxes[ino, beg_pos:end_pos])
output_assign_box = np.array(output_assign_box)
return pred_boxes, output_assign_box
class TestBoxDecoderAndAssignOpWithLoD(OpTest):
def test_check_output(self):
self.check_output()
def setUp(self):
self.op_type = "box_decoder_and_assign"
lod = [[4, 8, 8]]
num_classes = 10
prior_box = np.random.random((20, 4)).astype('float32')
prior_box_var = np.array([0.1, 0.1, 0.2, 0.2], dtype=np.float32)
target_box = np.random.random((20, 4 * num_classes)).astype('float32')
box_score = np.random.random((20, num_classes)).astype('float32')
box_clip = 4.135
output_box, output_assign_box = box_decoder_and_assign(
target_box, prior_box_var, prior_box, box_score, box_clip)
self.inputs = {
'PriorBox': (prior_box, lod),
'PriorBoxVar': prior_box_var,
'TargetBox': (target_box, lod),
'BoxScore': (box_score, lod),
}
self.attrs = {'box_clip': box_clip}
self.outputs = {
'OutputBox': output_box,
'OutputAssignBox': output_assign_box
}
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册