提交 666c94e3 编写于 作者: Y Yuan Gao 提交者: qingqing01

Add default prior box var for box_coder_op (#11164)

* add normalize switch to box_coder_op

* add default prior box var

* update according to the review
上级 c12c041e
......@@ -22,21 +22,21 @@ class BoxCoderOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("PriorBox"),
"Input(PriorBox) of BoxCoderOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("PriorBoxVar"),
"Input(PriorBoxVar) of BoxCoderOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("TargetBox"),
"Input(TargetBox) of BoxCoderOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("OutputBox"),
"Output(OutputBox) of BoxCoderOp should not be null.");
auto prior_box_dims = ctx->GetInputDim("PriorBox");
auto prior_box_var_dims = ctx->GetInputDim("PriorBoxVar");
auto target_box_dims = ctx->GetInputDim("TargetBox");
PADDLE_ENFORCE_EQ(prior_box_dims.size(), 2,
"The rank of Input of PriorBoxVar must be 2");
PADDLE_ENFORCE_EQ(prior_box_dims[1], 4, "The shape of PriorBox is [N, 4]");
PADDLE_ENFORCE_EQ(prior_box_dims, prior_box_var_dims);
if (ctx->HasInput("PriorBoxVar")) {
auto prior_box_var_dims = ctx->GetInputDim("PriorBoxVar");
PADDLE_ENFORCE_EQ(prior_box_dims, prior_box_var_dims);
}
auto code_type = GetBoxCodeType(ctx->Attrs().Get<std::string>("code_type"));
if (code_type == BoxCodeType::kEncodeCenterSize) {
......@@ -71,9 +71,11 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
"of the coordinate system. [xmax, ymax] is the right bottom "
"coordinate of the anchor box.");
AddInput("PriorBoxVar",
"(Tensor, default Tensor<float>) "
"(Tensor, default Tensor<float>, optional) "
"PriorBoxVar is a 2-D Tensor with shape [M, 4] holds M group "
"of variance.");
"of variance. PriorBoxVar will set all elements to 1 by "
"default.")
.AsDispensable();
AddInput(
"TargetBox",
"(LoDTensor or Tensor) This input can be a 2-D LoDTensor with shape "
......@@ -131,5 +133,6 @@ width and height.
namespace ops = paddle::operators;
REGISTER_OPERATOR(box_coder, ops::BoxCoderOp, ops::BoxCoderOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(box_coder, ops::BoxCoderKernel<float>,
ops::BoxCoderKernel<double>);
REGISTER_OP_CPU_KERNEL(
box_coder, ops::BoxCoderKernel<paddle::platform::CPUDeviceContext, float>,
ops::BoxCoderKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -48,15 +48,18 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data,
target_box_data[row_idx * len + 1] +
(normalized == false);
output[idx * len] = (target_box_center_x - prior_box_center_x) /
prior_box_width / prior_box_var_data[col_idx * len];
output[idx * len + 1] = (target_box_center_y - prior_box_center_y) /
prior_box_height /
prior_box_var_data[col_idx * len + 1];
output[idx * len + 2] = log(fabs(target_box_width / prior_box_width)) /
prior_box_var_data[col_idx * len + 2];
output[idx * len + 3] = log(fabs(target_box_height / prior_box_height)) /
prior_box_var_data[col_idx * len + 3];
output[idx * len] =
(target_box_center_x - prior_box_center_x) / prior_box_width;
output[idx * len + 1] =
(target_box_center_y - prior_box_center_y) / prior_box_height;
output[idx * len + 2] = log(fabs(target_box_width / prior_box_width));
output[idx * len + 3] = log(fabs(target_box_height / prior_box_height));
if (prior_box_var_data) {
output[idx * len] /= prior_box_var_data[col_idx * len];
output[idx * len + 1] /= prior_box_var_data[col_idx * len + 1];
output[idx * len + 2] /= prior_box_var_data[col_idx * len + 2];
output[idx * len + 3] /= prior_box_var_data[col_idx * len + 3];
}
}
}
......@@ -79,20 +82,31 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data,
T prior_box_center_y = (prior_box_data[col_idx * len + 3] +
prior_box_data[col_idx * len + 1]) /
2;
T target_box_width = exp(prior_box_var_data[col_idx * len + 2] *
T target_box_width, target_box_height;
T target_box_center_x, target_box_center_y;
if (prior_box_var_data) {
target_box_width = exp(prior_box_var_data[col_idx * len + 2] *
target_box_data[idx * len + 2]) *
prior_box_width;
T target_box_height = exp(prior_box_var_data[col_idx * len + 3] *
target_box_height = exp(prior_box_var_data[col_idx * len + 3] *
target_box_data[idx * len + 3]) *
prior_box_height;
T target_box_center_x = prior_box_var_data[col_idx * len] *
target_box_center_x = prior_box_var_data[col_idx * len] *
target_box_data[idx * len] * prior_box_width +
prior_box_center_x;
T target_box_center_y = prior_box_var_data[col_idx * len + 1] *
target_box_center_y = prior_box_var_data[col_idx * len + 1] *
target_box_data[idx * len + 1] *
prior_box_height +
prior_box_center_y;
} else {
target_box_width = exp(target_box_data[idx * len + 2]) * prior_box_width;
target_box_height =
exp(target_box_data[idx * len + 3]) * prior_box_height;
target_box_center_x =
target_box_data[idx * len] * prior_box_width + prior_box_center_x;
target_box_center_y = target_box_data[idx * len + 1] * prior_box_height +
prior_box_center_y;
}
output[idx * len] = target_box_center_x - target_box_width / 2;
output[idx * len + 1] = target_box_center_y - target_box_height / 2;
......@@ -103,7 +117,7 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data,
}
}
template <typename T>
template <typename DeviceContext, typename T>
class BoxCoderCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -114,6 +128,11 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
auto* output_box = context.Output<framework::Tensor>("OutputBox");
const T* prior_box_data = prior_box->data<T>();
const T* target_box_data = target_box->data<T>();
const T* prior_box_var_data = nullptr;
if (prior_box_var) prior_box_var_data = prior_box_var->data<T>();
if (target_box->lod().size()) {
PADDLE_ENFORCE_EQ(target_box->lod().size(), 1,
"Only support 1 level of LoD.");
......@@ -125,10 +144,6 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
int grid = (row * col + block - 1) / block;
auto& device_ctx = context.cuda_device_context();
const T* prior_box_data = prior_box->data<T>();
const T* prior_box_var_data = prior_box_var->data<T>();
const T* target_box_data = target_box->data<T>();
output_box->mutable_data<T>({row, col, len}, context.GetPlace());
T* output = output_box->data<T>();
......@@ -150,5 +165,7 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(box_coder, ops::BoxCoderCUDAKernel<float>,
ops::BoxCoderCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(
box_coder,
ops::BoxCoderCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::BoxCoderCUDAKernel<paddle::platform::CUDADeviceContext, double>);
......@@ -28,19 +28,20 @@ inline BoxCodeType GetBoxCodeType(const std::string& type) {
PADDLE_THROW("Not support type %s.", type);
}
template <typename T>
template <typename DeviceContext, typename T>
class BoxCoderKernel : public framework::OpKernel<T> {
public:
void EncodeCenterSize(const framework::Tensor& target_box,
const framework::Tensor& prior_box,
const framework::Tensor& prior_box_var,
void EncodeCenterSize(const framework::Tensor* target_box,
const framework::Tensor* prior_box,
const framework::Tensor* prior_box_var,
const bool normalized, T* output) const {
int64_t row = target_box.dims()[0];
int64_t col = prior_box.dims()[0];
int64_t len = prior_box.dims()[1];
auto* target_box_data = target_box.data<T>();
auto* prior_box_data = prior_box.data<T>();
auto* prior_box_var_data = prior_box_var.data<T>();
int64_t row = target_box->dims()[0];
int64_t col = prior_box->dims()[0];
int64_t len = prior_box->dims()[1];
auto* target_box_data = target_box->data<T>();
auto* prior_box_data = prior_box->data<T>();
const T* prior_box_var_data = nullptr;
if (prior_box_var) prior_box_var_data = prior_box_var->data<T>();
for (int64_t i = 0; i < row; ++i) {
for (int64_t j = 0; j < col; ++j) {
......@@ -65,30 +66,35 @@ class BoxCoderKernel : public framework::OpKernel<T> {
(normalized == false);
size_t offset = i * col * len + j * len;
output[offset] = (target_box_center_x - prior_box_center_x) /
prior_box_width / prior_box_var_data[j * len];
output[offset + 1] = (target_box_center_y - prior_box_center_y) /
prior_box_height / prior_box_var_data[j * len + 1];
output[offset] =
(target_box_center_x - prior_box_center_x) / prior_box_width;
output[offset + 1] =
(target_box_center_y - prior_box_center_y) / prior_box_height;
output[offset + 2] =
std::log(std::fabs(target_box_width / prior_box_width)) /
prior_box_var_data[j * len + 2];
std::log(std::fabs(target_box_width / prior_box_width));
output[offset + 3] =
std::log(std::fabs(target_box_height / prior_box_height)) /
prior_box_var_data[j * len + 3];
std::log(std::fabs(target_box_height / prior_box_height));
if (prior_box_var) {
output[offset] /= prior_box_var_data[j * len];
output[offset + 1] /= prior_box_var_data[j * len + 1];
output[offset + 2] /= prior_box_var_data[j * len + 2];
output[offset + 3] /= prior_box_var_data[j * len + 3];
}
}
}
}
void DecodeCenterSize(const framework::Tensor& target_box,
const framework::Tensor& prior_box,
const framework::Tensor& prior_box_var,
void DecodeCenterSize(const framework::Tensor* target_box,
const framework::Tensor* prior_box,
const framework::Tensor* prior_box_var,
const bool normalized, T* output) const {
int64_t row = target_box.dims()[0];
int64_t col = prior_box.dims()[0];
int64_t len = prior_box.dims()[1];
int64_t row = target_box->dims()[0];
int64_t col = prior_box->dims()[0];
int64_t len = prior_box->dims()[1];
auto* target_box_data = target_box.data<T>();
auto* prior_box_data = prior_box.data<T>();
auto* prior_box_var_data = prior_box_var.data<T>();
auto* target_box_data = target_box->data<T>();
auto* prior_box_data = prior_box->data<T>();
const T* prior_box_var_data = nullptr;
if (prior_box_var) prior_box_var_data = prior_box_var->data<T>();
for (int64_t i = 0; i < row; ++i) {
for (int64_t j = 0; j < col; ++j) {
......@@ -103,19 +109,32 @@ class BoxCoderKernel : public framework::OpKernel<T> {
T prior_box_center_y =
(prior_box_data[j * len + 3] + prior_box_data[j * len + 1]) / 2;
T target_box_center_x = prior_box_var_data[j * len] *
T target_box_center_x = 0, target_box_center_y = 0;
T target_box_width = 0, target_box_height = 0;
if (prior_box_var) {
target_box_center_x = prior_box_var_data[j * len] *
target_box_data[offset] * prior_box_width +
prior_box_center_x;
T target_box_center_y = prior_box_var_data[j * len + 1] *
target_box_center_y = prior_box_var_data[j * len + 1] *
target_box_data[offset + 1] *
prior_box_height +
prior_box_center_y;
T target_box_width = std::exp(prior_box_var_data[j * len + 2] *
target_box_width = std::exp(prior_box_var_data[j * len + 2] *
target_box_data[offset + 2]) *
prior_box_width;
T target_box_height = std::exp(prior_box_var_data[j * len + 3] *
target_box_height = std::exp(prior_box_var_data[j * len + 3] *
target_box_data[offset + 3]) *
prior_box_height;
} else {
target_box_center_x =
target_box_data[offset] * prior_box_width + prior_box_center_x;
target_box_center_y = target_box_data[offset + 1] * prior_box_height +
prior_box_center_y;
target_box_width =
std::exp(target_box_data[offset + 2]) * prior_box_width;
target_box_height =
std::exp(target_box_data[offset + 3]) * prior_box_height;
}
output[offset] = target_box_center_x - target_box_width / 2;
output[offset + 1] = target_box_center_y - target_box_height / 2;
......@@ -147,10 +166,10 @@ class BoxCoderKernel : public framework::OpKernel<T> {
bool normalized = context.Attr<bool>("box_normalized");
T* output = output_box->data<T>();
if (code_type == BoxCodeType::kEncodeCenterSize) {
EncodeCenterSize(*target_box, *prior_box, *prior_box_var, normalized,
EncodeCenterSize(target_box, prior_box, prior_box_var, normalized,
output);
} else if (code_type == BoxCodeType::kDecodeCenterSize) {
DecodeCenterSize(*target_box, *prior_box, *prior_box_var, normalized,
DecodeCenterSize(target_box, prior_box, prior_box_var, normalized,
output);
}
}
......
......@@ -120,6 +120,32 @@ class TestBoxCoderOp(OpTest):
self.outputs = {'OutputBox': output_box}
class TestBoxCoderOpWithoutBoxVar(OpTest):
def test_check_output(self):
self.check_output()
def setUp(self):
self.op_type = "box_coder"
lod = [[0, 1, 2, 3, 4, 5]]
prior_box = np.random.random((10, 4)).astype('float32')
prior_box_var = np.ones((10, 4)).astype('float32')
target_box = np.random.random((5, 10, 4)).astype('float32')
code_type = "DecodeCenterSize"
box_normalized = False
output_box = batch_box_coder(prior_box, prior_box_var, target_box,
lod[0], code_type, box_normalized)
self.inputs = {
'PriorBox': prior_box,
'TargetBox': target_box,
}
self.attrs = {
'code_type': 'decode_center_size',
'box_normalized': False
}
self.outputs = {'OutputBox': output_box}
class TestBoxCoderOpWithLoD(OpTest):
def test_check_output(self):
self.check_output()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册