diff --git a/paddle/operators/box_coder_op.cc b/paddle/operators/box_coder_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0cb20a4182e786a0183ab586d8c97c2e30682760
--- /dev/null
+++ b/paddle/operators/box_coder_op.cc
@@ -0,0 +1,121 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/box_coder_op.h"
+
+namespace paddle {
+namespace operators {
+
+// Operator definition for box_coder: validates the inputs and infers the
+// shape of OutputBox. PriorBox and PriorBoxVar are [M, 4], TargetBox is
+// [N, 4] and OutputBox is [N, M, 4].
+class BoxCoderOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  // Checks that all three inputs are present and 2-D with 4 columns, that
+  // code_type is a supported value, and sets OutputBox to [N, M, 4].
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("PriorBox"),
+                   "Input(PriorBox) of BoxCoderOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("PriorBoxVar"),
+                   "Input(PriorBoxVar) of BoxCoderOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("TargetBox"),
+                   "Input(TargetBox) of BoxCoderOp should not be null.");
+
+    auto prior_box_dims = ctx->GetInputDim("PriorBox");
+    auto prior_box_var_dims = ctx->GetInputDim("PriorBoxVar");
+    auto target_box_dims = ctx->GetInputDim("TargetBox");
+
+    PADDLE_ENFORCE_EQ(prior_box_dims.size(), 2UL,
+                      "The shape of PriorBox is [M, 4]");
+    PADDLE_ENFORCE_EQ(prior_box_dims[1], 4UL,
+                      "The shape of PriorBox is [M, 4]");
+    PADDLE_ENFORCE_EQ(prior_box_var_dims.size(), 2UL,
+                      "The shape of PriorBoxVar is [M, 4]");
+    PADDLE_ENFORCE_EQ(prior_box_var_dims[1], 4UL,
+                      "The shape of PriorBoxVar is [M, 4]");
+    PADDLE_ENFORCE_EQ(target_box_dims.size(), 2UL,
+                      "The shape of TargetBox is [N, 4]");
+    PADDLE_ENFORCE_EQ(target_box_dims[1], 4UL,
+                      "The shape of TargetBox is [N, 4]");
+    // The kernels index PriorBoxVar with the PriorBox row index, so both
+    // must hold the same number of rows.
+    PADDLE_ENFORCE_EQ(prior_box_dims[0], prior_box_var_dims[0],
+                      "PriorBoxVar must have as many rows as PriorBox.");
+
+    // Validate the attribute early; GetBoxCodeType throws on unknown values.
+    GetBoxCodeType(ctx->Attrs().Get<std::string>("code_type"));
+
+    // One encoded/decoded box per (target, prior) pair: [N, M, 4], matching
+    // what the CPU and CUDA kernels allocate for OutputBox.
+    ctx->SetOutputDim(
+        "OutputBox",
+        framework::make_ddim({target_box_dims[0], prior_box_dims[0], 4}));
+  }
+};
+
+// Declares the inputs, the output and the code_type attribute of box_coder.
+class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  BoxCoderOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput(
+        "PriorBox",
+        "(Tensor, default Tensor<float>) "
+        "Box list PriorBox is a 2-D Tensor with shape [M, 4] holds M boxes, "
+        "each box is represented as [xmin, ymin, xmax, ymax], "
+        "[xmin, ymin] is the left top coordinate of the anchor box, "
+        "if the input is image feature map, they are close to the origin "
+        "of the coordinate system. [xmax, ymax] is the right bottom "
+        "coordinate of the anchor box.");
+    AddInput("PriorBoxVar",
+             "(Tensor, default Tensor<float>) "
+             "PriorBoxVar is a 2-D Tensor with shape [M, 4] holds M groups "
+             "of variance.");
+    AddInput(
+        "TargetBox",
+        "(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape "
+        "[N, 4], each box is represented as [xmin, ymin, xmax, ymax], "
+        "[xmin, ymin] is the left top coordinate of the box if the input "
+        "is image feature map, they are close to the origin of the coordinate "
+        "system. [xmax, ymax] is the right bottom coordinate of the box. "
+        "This tensor can contain LoD information to represent a batch "
+        "of inputs. One instance of this batch can contain different "
+        "numbers of entities.");
+    AddAttr<std::string>("code_type",
+                         "(string, default encode_center_size) "
+                         "the code type used with the target box")
+        .SetDefault("encode_center_size")
+        .InEnum({"encode_center_size", "decode_center_size"});
+    AddOutput(
+        "OutputBox",
+        "(Tensor, default Tensor<float>) "
+        "The output of box_coder_op, a tensor with shape [N, M, 4] "
+        "representing the result of N target boxes encoded/decoded with "
+        "M Prior boxes and variances.");
+
+    AddComment(R"DOC(
+Bounding Box Coder Operator.
+Encode/Decode the priorbox information with the target bounding box.
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+// box_coder has no gradient op; CPU kernels are registered for float/double.
+namespace ops = paddle::operators;
+REGISTER_OP_WITHOUT_GRADIENT(box_coder, ops::BoxCoderOp, ops::BoxCoderOpMaker);
+REGISTER_OP_CPU_KERNEL(box_coder, ops::BoxCoderKernel<float>,
+                       ops::BoxCoderKernel<double>);
diff --git a/paddle/operators/box_coder_op.cu b/paddle/operators/box_coder_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..4055ded1f8b61d21b2b9d39b5ed41c40c221d3fb
--- /dev/null
+++ b/paddle/operators/box_coder_op.cu
@@ -0,0 +1,145 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/
+
+#include "paddle/operators/box_coder_op.h"
+#include "paddle/platform/cuda_helper.h"
+
+namespace paddle {
+namespace operators {
+
+using platform::PADDLE_CUDA_NUM_THREADS;
+// Encodes every target box against every prior box; one thread per pair.
+template <typename T>
+__global__ void EncodeCenterSizeKernel(const T* prior_box_data,
+                                       const T* prior_box_var_data,
+                                       const T* target_box_data, int row,
+                                       int col, T* output) {
+  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < row * col) {  // threads past the last (target, prior) pair idle
+    const int row_idx = idx / col;  // target box index
+    const int col_idx = idx % col;  // prior box (and variance) index
+    T prior_box_width =
+        prior_box_data[col_idx * 4 + 2] - prior_box_data[col_idx * 4];
+    T prior_box_height =
+        prior_box_data[col_idx * 4 + 3] - prior_box_data[col_idx * 4 + 1];
+    T prior_box_center_x =
+        (prior_box_data[col_idx * 4 + 2] + prior_box_data[col_idx * 4]) / 2;
+    T prior_box_center_y =
+        (prior_box_data[col_idx * 4 + 3] + prior_box_data[col_idx * 4 + 1]) / 2;
+    // Target box converted from corner form to center/size form.
+    T target_box_center_x =
+        (target_box_data[row_idx * 4 + 2] + target_box_data[row_idx * 4]) / 2;
+    T target_box_center_y =
+        (target_box_data[row_idx * 4 + 3] + target_box_data[row_idx * 4 + 1]) /
+        2;
+    T target_box_width =
+        target_box_data[row_idx * 4 + 2] - target_box_data[row_idx * 4];
+    T target_box_height =
+        target_box_data[row_idx * 4 + 3] - target_box_data[row_idx * 4 + 1];
+    // Center offsets scaled by prior size, sizes as log-ratios, all / variance.
+    output[idx * 4] = (target_box_center_x - prior_box_center_x) /
+                      prior_box_width / prior_box_var_data[col_idx * 4];
+    output[idx * 4 + 1] = (target_box_center_y - prior_box_center_y) /
+                          prior_box_height /
+                          prior_box_var_data[col_idx * 4 + 1];
+    output[idx * 4 + 2] = log(fabs(target_box_width / prior_box_width)) /
+                          prior_box_var_data[col_idx * 4 + 2];
+    output[idx * 4 + 3] = log(fabs(target_box_height / prior_box_height)) /
+                          prior_box_var_data[col_idx * 4 + 3];
+  }
+}
+// Decodes every target code against every prior box; one thread per pair.
+template <typename T>
+__global__ void DecodeCenterSizeKernel(const T* prior_box_data,
+                                       const T* prior_box_var_data,
+                                       const T* target_box_data, int row,
+                                       int col, T* output) {
+  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < row * col) {  // threads past the last (target, prior) pair idle
+    const int row_idx = idx / col;  // target code index
+    const int col_idx = idx % col;  // prior box (and variance) index
+    T prior_box_width =
+        prior_box_data[col_idx * 4 + 2] - prior_box_data[col_idx * 4];
+    T prior_box_height =
+        prior_box_data[col_idx * 4 + 3] - prior_box_data[col_idx * 4 + 1];
+    T prior_box_center_x =
+        (prior_box_data[col_idx * 4 + 2] + prior_box_data[col_idx * 4]) / 2;
+    T prior_box_center_y =
+        (prior_box_data[col_idx * 4 + 3] + prior_box_data[col_idx * 4 + 1]) / 2;
+    // Invert the encoding: exp() for the sizes, scale-and-shift for centers.
+    T target_box_width = exp(prior_box_var_data[col_idx * 4 + 2] *
+                             target_box_data[row_idx * 4 + 2]) *
+                         prior_box_width;
+    T target_box_height = exp(prior_box_var_data[col_idx * 4 + 3] *
+                              target_box_data[row_idx * 4 + 3]) *
+                          prior_box_height;
+    T target_box_center_x = prior_box_var_data[col_idx * 4] *
+                                target_box_data[row_idx * 4] * prior_box_width +
+                            prior_box_center_x;
+    T target_box_center_y = prior_box_var_data[col_idx * 4 + 1] *
+                                target_box_data[row_idx * 4 + 1] *
+                                prior_box_height +
+                            prior_box_center_y;
+    // Write the decoded box back as two corners (center -/+ half size).
+    output[idx * 4] = target_box_center_x - target_box_width / 2;
+    output[idx * 4 + 1] = target_box_center_y - target_box_height / 2;
+    output[idx * 4 + 2] = target_box_center_x + target_box_width / 2;
+    output[idx * 4 + 3] = target_box_center_y + target_box_height / 2;
+  }
+}
+// CUDA kernel of box_coder: launches one thread per (target, prior) pair.
+template <typename T>
+class BoxCoderCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
+                   "This kernel only runs on GPU device.");
+    auto* prior_box = context.Input<framework::Tensor>("PriorBox");
+    auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
+    auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
+    auto* output_box = context.Output<framework::Tensor>("OutputBox");
+    // LoD on TargetBox is optional, but at most one level is accepted.
+    if (target_box->lod().size()) {
+      PADDLE_ENFORCE_EQ(target_box->lod().size(), 1UL,
+                        "Only support 1 level of LoD.");
+    }
+    auto row = target_box->dims()[0];
+    auto col = prior_box->dims()[0];
+    int block = 512;
+    int grid = (row * col + block - 1) / block;  // ceil(row * col / block)
+    auto& device_ctx = context.cuda_device_context();
+
+    const T* prior_box_data = prior_box->data<T>();
+    const T* prior_box_var_data = prior_box_var->data<T>();
+    const T* target_box_data = target_box->data<T>();
+    // OutputBox holds one 4-value box per (target, prior) pair.
+    output_box->mutable_data<T>({row, col, 4}, context.GetPlace());
+    T* output = output_box->data<T>();
+
+    auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
+    if (code_type == BoxCodeType::kEncodeCenterSize) {
+      EncodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
+          prior_box_data, prior_box_var_data, target_box_data, row, col,
+          output);
+    } else if (code_type == BoxCodeType::kDecodeCenterSize) {
+      DecodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
+          prior_box_data, prior_box_var_data, target_box_data, row, col,
+          output);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(box_coder, ops::BoxCoderCUDAKernel<float>,
+                        ops::BoxCoderCUDAKernel<double>);
diff --git a/paddle/operators/box_coder_op.h b/paddle/operators/box_coder_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..3865da40c33d57971dc4a4dfc817bb23f466868a
--- /dev/null
+++ b/paddle/operators/box_coder_op.h
@@ -0,0 +1,163 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/
+
+#pragma once
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+// Supported coding directions, selected by the "code_type" attribute.
+enum class BoxCodeType { kEncodeCenterSize = 0, kDecodeCenterSize = 1 };
+// Maps the code_type attribute string to BoxCodeType; throws on other values.
+inline BoxCodeType GetBoxCodeType(const std::string& type) {
+  if (type == "encode_center_size") {
+    return BoxCodeType::kEncodeCenterSize;
+  } else if (type == "decode_center_size") {
+    return BoxCodeType::kDecodeCenterSize;
+  }
+  PADDLE_THROW("Not support type %s.", type);
+}
+// CPU kernel of box_coder; output is laid out as [row, col, 4], row-major.
+template <typename T>
+class BoxCoderKernel : public framework::OpKernel<T> {
+ public:
+  void EncodeCenterSize(const Tensor& target_box, const Tensor& prior_box,
+                        const Tensor& prior_box_var, T* output) const {
+    PADDLE_ENFORCE_EQ(target_box.dims().size(), 2,
+                      "The rank of target_box must be 2.");
+    PADDLE_ENFORCE_EQ(prior_box.dims().size(), 2,
+                      "The rank of prior_box must be 2.");
+    PADDLE_ENFORCE_EQ(prior_box_var.dims().size(), 2,
+                      "The rank of prior_box_var must be 2.");
+    PADDLE_ENFORCE_EQ(prior_box.dims()[0], prior_box_var.dims()[0],
+                      "The dims of prior_box must equal to prior_box_var.");
+    // Encode each target box (row) relative to each prior box (col).
+    int64_t row = target_box.dims()[0];
+    int64_t col = prior_box.dims()[0];
+    auto* target_box_data = target_box.data<T>();
+    auto* prior_box_data = prior_box.data<T>();
+    auto* prior_box_var_data = prior_box_var.data<T>();
+
+    for (int64_t i = 0; i < row; ++i) {
+      for (int64_t j = 0; j < col; ++j) {
+        T prior_box_width = prior_box_data[j * 4 + 2] - prior_box_data[j * 4];
+        T prior_box_height =
+            prior_box_data[j * 4 + 3] - prior_box_data[j * 4 + 1];
+        T prior_box_center_x =
+            (prior_box_data[j * 4 + 2] + prior_box_data[j * 4]) / 2;
+        T prior_box_center_y =
+            (prior_box_data[j * 4 + 3] + prior_box_data[j * 4 + 1]) / 2;
+        // Target box i converted from corner form to center/size form.
+        T target_box_center_x =
+            (target_box_data[i * 4 + 2] + target_box_data[i * 4]) / 2;
+        T target_box_center_y =
+            (target_box_data[i * 4 + 3] + target_box_data[i * 4 + 1]) / 2;
+        T target_box_width =
+            target_box_data[i * 4 + 2] - target_box_data[i * 4];
+        T target_box_height =
+            target_box_data[i * 4 + 3] - target_box_data[i * 4 + 1];
+        // Output slot of pair (i, j) in the [row, col, 4] buffer.
+        size_t offset = i * col * 4 + j * 4;
+        output[offset] = (target_box_center_x - prior_box_center_x) /
+                         prior_box_width / prior_box_var_data[j * 4];
+        output[offset + 1] = (target_box_center_y - prior_box_center_y) /
+                             prior_box_height / prior_box_var_data[j * 4 + 1];
+        output[offset + 2] =
+            std::log(std::fabs(target_box_width / prior_box_width)) /
+            prior_box_var_data[j * 4 + 2];
+        output[offset + 3] =
+            std::log(std::fabs(target_box_height / prior_box_height)) /
+            prior_box_var_data[j * 4 + 3];
+      }
+    }
+  }
+  void DecodeCenterSize(const Tensor& target_box, const Tensor& prior_box,
+                        const Tensor& prior_box_var, T* output) const {
+    PADDLE_ENFORCE_EQ(target_box.dims().size(), 2,
+                      "The rank of target_box must be 2.");
+    PADDLE_ENFORCE_EQ(prior_box.dims().size(), 2,
+                      "The rank of prior_box must be 2.");
+    PADDLE_ENFORCE_EQ(prior_box_var.dims().size(), 2,
+                      "The rank of prior_box_var must be 2.");
+    PADDLE_ENFORCE_EQ(prior_box.dims()[0], prior_box_var.dims()[0],
+                      "The dims of prior_box must equal to prior_box_var.");
+    // Decode each target code (row) relative to each prior box (col).
+    int64_t row = target_box.dims()[0];
+    int64_t col = prior_box.dims()[0];
+
+    auto* target_box_data = target_box.data<T>();
+    auto* prior_box_data = prior_box.data<T>();
+    auto* prior_box_var_data = prior_box_var.data<T>();
+
+    for (int64_t i = 0; i < row; ++i) {
+      for (int64_t j = 0; j < col; ++j) {
+        T prior_box_width = prior_box_data[j * 4 + 2] - prior_box_data[j * 4];
+        T prior_box_height =
+            prior_box_data[j * 4 + 3] - prior_box_data[j * 4 + 1];
+        T prior_box_center_x =
+            (prior_box_data[j * 4 + 2] + prior_box_data[j * 4]) / 2;
+        T prior_box_center_y =
+            (prior_box_data[j * 4 + 3] + prior_box_data[j * 4 + 1]) / 2;
+        // Invert the encoding: scale-and-shift centers, exp() for sizes.
+        T target_box_center_x = prior_box_var_data[j * 4] *
+                                    target_box_data[i * 4] * prior_box_width +
+                                prior_box_center_x;
+        T target_box_center_y = prior_box_var_data[j * 4 + 1] *
+                                    target_box_data[i * 4 + 1] *
+                                    prior_box_height +
+                                prior_box_center_y;
+        T target_box_width = std::exp(prior_box_var_data[j * 4 + 2] *
+                                      target_box_data[i * 4 + 2]) *
+                             prior_box_width;
+        T target_box_height = std::exp(prior_box_var_data[j * 4 + 3] *
+                                       target_box_data[i * 4 + 3]) *
+                              prior_box_height;
+        // Store pair (i, j) as two corners (center -/+ half size).
+        size_t offset = i * col * 4 + j * 4;
+        output[offset] = target_box_center_x - target_box_width / 2;
+        output[offset + 1] = target_box_center_y - target_box_height / 2;
+        output[offset + 2] = target_box_center_x + target_box_width / 2;
+        output[offset + 3] = target_box_center_y + target_box_height / 2;
+      }
+    }
+  }
+
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* prior_box = context.Input<framework::Tensor>("PriorBox");
+    auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
+    auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
+    auto* output_box = context.Output<framework::Tensor>("OutputBox");
+    // LoD on TargetBox is optional, but at most one level is accepted.
+    if (target_box->lod().size()) {
+      PADDLE_ENFORCE_EQ(target_box->lod().size(), 1UL,
+                        "Only support 1 level of LoD.");
+    }
+    auto row = target_box->dims()[0];
+    auto col = prior_box->dims()[0];
+
+    output_box->mutable_data<T>({row, col, 4}, context.GetPlace());
+
+    auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
+    T* output = output_box->data<T>();
+    if (code_type == BoxCodeType::kEncodeCenterSize) {
+      EncodeCenterSize(*target_box, *prior_box, *prior_box_var, output);
+    } else if (code_type == BoxCodeType::kDecodeCenterSize) {
+      DecodeCenterSize(*target_box, *prior_box, *prior_box_var, output);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/v2/fluid/tests/test_box_coder_op.py b/python/paddle/v2/fluid/tests/test_box_coder_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcf5da01ced723e5b7bd3e6cc8666c1c36053efe
--- /dev/null
+++ b/python/paddle/v2/fluid/tests/test_box_coder_op.py
@@ -0,0 +1,117 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+import sys
+import math
+from op_test import OpTest
+
+# NumPy reference: writes the encoded/decoded boxes into output_box in place.
+def box_coder(target_box, prior_box, prior_box_var, output_box, code_type):
+    prior_box_x = (prior_box[:, 2] + prior_box[:, 0]) / 2
+    prior_box_y = (prior_box[:, 3] + prior_box[:, 1]) / 2
+    prior_box_width = (prior_box[:, 2] - prior_box[:, 0])
+    prior_box_height = (prior_box[:, 3] - prior_box[:, 1])
+
+    if (code_type == "EncodeCenterSize"):
+        target_box_x = (target_box[:, 2] + target_box[:, 0]) / 2
+        target_box_y = (target_box[:, 3] + target_box[:, 1]) / 2
+        target_box_width = (target_box[:, 2] - target_box[:, 0])
+        target_box_height = (target_box[:, 3] - target_box[:, 1])
+        # Row i of the output holds target i encoded against every prior box.
+        for i in range(target_box.shape[0]):
+            output_box[i,:,0] = (target_box_x[i] - prior_box_x) / prior_box_width / \
+                prior_box_var[:,0]
+            output_box[i,:,1] = (target_box_y[i] - prior_box_y) / prior_box_height / \
+                prior_box_var[:,1]
+            output_box[i,:,2] = np.log(np.fabs(target_box_width[i] / prior_box_width)) / \
+                prior_box_var[:,2]
+            output_box[i,:,3] = np.log(np.fabs(target_box_height[i] / prior_box_height)) / \
+                prior_box_var[:,3]
+
+    elif (code_type == "DecodeCenterSize"):
+        for i in range(target_box.shape[0]):
+            target_box_x = prior_box_var[:,0] * target_box[i][0] * \
+                prior_box_width[:] + prior_box_x[:]
+            target_box_y = prior_box_var[:,1] * target_box[i][1] * \
+                prior_box_height[:] + prior_box_y[:]
+            target_box_width = np.exp(prior_box_var[:,2] * target_box[i][2]) * \
+                prior_box_width[:]
+            target_box_height = np.exp(prior_box_var[:,3] * target_box[i][3]) * \
+                prior_box_height[:]
+            output_box[i, :, 0] = target_box_x - target_box_width / 2
+            output_box[i, :, 1] = target_box_y - target_box_height / 2
+            output_box[i, :, 2] = target_box_x + target_box_width / 2
+            output_box[i, :, 3] = target_box_y + target_box_height / 2
+
+# Applies box_coder to each [lod[i], lod[i + 1]) slice of target_box.
+def batch_box_coder(prior_box, prior_box_var, target_box, lod, code_type):
+    n = target_box.shape[0]
+    m = prior_box.shape[0]
+    output_box = np.zeros((n, m, 4), dtype=np.float32)
+    for i in range(len(lod) - 1):
+        box_coder(target_box[lod[i]:lod[i + 1], :], prior_box, prior_box_var,
+                  output_box[lod[i]:lod[i + 1], :, :], code_type)
+    return output_box
+
+# Decode test: TargetBox is fed as a plain ndarray, without LoD.
+class TestBoxCoderOp(OpTest):
+    def test_check_output(self):
+        self.check_output()
+
+    def setUp(self):
+        self.op_type = "box_coder"
+        lod = [[0, 20]]
+        prior_box = np.random.random((10, 4)).astype('float32')
+        prior_box_var = np.random.random((10, 4)).astype('float32')
+        target_box = np.random.random((20, 4)).astype('float32')
+        code_type = "DecodeCenterSize"
+        output_box = batch_box_coder(prior_box, prior_box_var, target_box,
+                                     lod[0], code_type)
+
+        self.inputs = {
+            'PriorBox': prior_box,
+            'PriorBoxVar': prior_box_var,
+            'TargetBox': target_box,
+        }
+        self.attrs = {'code_type': 'decode_center_size'}
+        self.outputs = {'OutputBox': output_box}
+
+# Encode test: TargetBox is fed together with its LoD offsets.
+class TestBoxCoderOpWithLoD(OpTest):
+    def test_check_output(self):
+        self.check_output()
+
+    def setUp(self):
+        self.op_type = "box_coder"
+        lod = [[0, 4, 12, 20]]
+        prior_box = np.random.random((10, 4)).astype('float32')
+        prior_box_var = np.random.random((10, 4)).astype('float32')
+        target_box = np.random.random((20, 4)).astype('float32')
+        code_type = "EncodeCenterSize"
+        output_box = batch_box_coder(prior_box, prior_box_var, target_box,
+                                     lod[0], code_type)
+
+        self.inputs = {
+            'PriorBox': prior_box,
+            'PriorBoxVar': prior_box_var,
+            'TargetBox': (target_box, lod),
+        }
+        self.attrs = {'code_type': 'encode_center_size'}
+        self.outputs = {'OutputBox': output_box}
+
+
+if __name__ == '__main__':
+    unittest.main()