提交 a39240c3 编写于 作者: J jerrywgz

add attr variance for box coder, test=develop

上级 6928f831
......@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/box_coder_op.h"
#include <vector>
namespace paddle {
namespace operators {
......@@ -134,6 +135,12 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
"when code type is decode_center_size")
.SetDefault(0)
.InEnum({0, 1});
AddAttr<std::vector<float>>(
"variance",
"(vector<float>, default {}),"
"variance of prior box with shape [4]. PriorBoxVar and variance can"
"not be provided at the same time.")
.SetDefault(std::vector<float>{});
AddOutput("OutputBox",
"(LoDTensor or Tensor) "
"When code_type is 'encode_center_size', the output tensor of "
......
......@@ -9,6 +9,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include "paddle/fluid/operators/detection/box_coder_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
......@@ -16,12 +18,11 @@ namespace paddle {
namespace operators {
template <typename T>
__global__ void EncodeCenterSizeKernel(const T* prior_box_data,
const T* prior_box_var_data,
const T* target_box_data, const int row,
const int col, const int len,
const bool normalized,
const T prior_box_var_size, T* output) {
__global__ void EncodeCenterSizeKernel(
const T* prior_box_data, const T* prior_box_var_data,
const T* target_box_data, const int row, const int col, const int len,
const bool normalized, const T prior_box_var_size, const float* variance,
const int var_size, T* output) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < row * col) {
const int row_idx = idx / col;
......@@ -62,18 +63,20 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data,
output[idx * len + 1] /= prior_box_var_data[prior_var_offset + 1];
output[idx * len + 2] /= prior_box_var_data[prior_var_offset + 2];
output[idx * len + 3] /= prior_box_var_data[prior_var_offset + 3];
} else if (var_size == 4) {
for (int k = 0; k < 4; ++k) {
output[idx * len + k] /= static_cast<T>(variance[k]);
}
}
}
}
template <typename T>
__global__ void DecodeCenterSizeKernel(const T* prior_box_data,
const T* prior_box_var_data,
const T* target_box_data, const int row,
const int col, const int len,
const bool normalized,
const T prior_box_var_size,
const int axis, T* output) {
__global__ void DecodeCenterSizeKernel(
const T* prior_box_data, const T* prior_box_var_data,
const T* target_box_data, const int row, const int col, const int len,
const bool normalized, const T prior_box_var_size, const float* variance,
const int var_size, const int axis, T* output) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
int prior_box_offset = 0;
if (idx < row * col) {
......@@ -110,6 +113,20 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data,
target_box_data[idx * len + 1] *
prior_box_height +
prior_box_center_y;
} else if (var_size == 4) {
target_box_width =
exp(static_cast<T>(variance[2]) * target_box_data[idx * len + 2]) *
prior_box_width;
target_box_height =
exp(static_cast<T>(variance[3]) * target_box_data[idx * len + 3]) *
prior_box_height;
target_box_center_x = static_cast<T>(variance[0]) *
target_box_data[idx * len] * prior_box_width +
prior_box_center_x;
target_box_center_y = static_cast<T>(variance[1]) *
target_box_data[idx * len + 1] *
prior_box_height +
prior_box_center_y;
} else {
target_box_width = exp(target_box_data[idx * len + 2]) * prior_box_width;
target_box_height =
......@@ -139,20 +156,30 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
auto* output_box = context.Output<framework::Tensor>("OutputBox");
std::vector<float> variance = context.Attr<std::vector<float>>("variance");
const T* prior_box_data = prior_box->data<T>();
const T* target_box_data = target_box->data<T>();
const T* prior_box_var_data = nullptr;
auto prior_box_var_size = 0;
if (prior_box_var) {
PADDLE_ENFORCE(variance.empty(),
"Input 'PriorBoxVar' and attribute 'variance' should not"
"be used at the same time.");
prior_box_var_data = prior_box_var->data<T>();
prior_box_var_size = prior_box_var->dims().size();
}
if (!(variance.empty())) {
PADDLE_ENFORCE(static_cast<int>(variance.size()) == 4,
"Size of attribute 'variance' should be 4");
}
if (target_box->lod().size()) {
PADDLE_ENFORCE_EQ(target_box->lod().size(), 1,
"Only support 1 level of LoD.");
}
const int var_size = static_cast<T>(variance.size());
thrust::device_vector<float> dev_variance(variance.begin(), variance.end());
const float* dev_var_data = thrust::raw_pointer_cast(dev_variance.data());
auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
bool normalized = context.Attr<bool>("box_normalized");
int axis = context.Attr<int>("axis");
......@@ -173,11 +200,11 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
if (code_type == BoxCodeType::kEncodeCenterSize) {
EncodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
prior_box_data, prior_box_var_data, target_box_data, row, col, len,
normalized, prior_box_var_size, output);
normalized, prior_box_var_size, dev_var_data, var_size, output);
} else if (code_type == BoxCodeType::kDecodeCenterSize) {
DecodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
prior_box_data, prior_box_var_data, target_box_data, row, col, len,
normalized, prior_box_var_size, axis, output);
normalized, prior_box_var_size, dev_var_data, var_size, axis, output);
}
}
};
......
......@@ -11,6 +11,7 @@ limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
......@@ -34,7 +35,8 @@ class BoxCoderKernel : public framework::OpKernel<T> {
void EncodeCenterSize(const framework::Tensor* target_box,
const framework::Tensor* prior_box,
const framework::Tensor* prior_box_var,
const bool normalized, T* output) const {
const bool normalized,
const std::vector<float> variance, T* output) const {
int64_t row = target_box->dims()[0];
int64_t col = prior_box->dims()[0];
int64_t len = prior_box->dims()[1];
......@@ -85,6 +87,10 @@ class BoxCoderKernel : public framework::OpKernel<T> {
output[offset + 1] /= prior_box_var_data[prior_var_offset + 1];
output[offset + 2] /= prior_box_var_data[prior_var_offset + 2];
output[offset + 3] /= prior_box_var_data[prior_var_offset + 3];
} else if (!(variance.empty())) {
for (int k = 0; k < 4; ++k) {
output[offset + k] /= static_cast<T>(variance[k]);
}
}
}
}
......@@ -93,7 +99,7 @@ class BoxCoderKernel : public framework::OpKernel<T> {
const framework::Tensor* prior_box,
const framework::Tensor* prior_box_var,
const bool normalized, const int axis,
T* output) const {
const std::vector<float> variance, T* output) const {
int64_t row = target_box->dims()[0];
int64_t col = target_box->dims()[1];
int64_t len = target_box->dims()[2];
......@@ -149,6 +155,20 @@ class BoxCoderKernel : public framework::OpKernel<T> {
std::exp(prior_box_var_data[prior_var_offset + 3] *
target_box_data[offset + 3]) *
prior_box_height;
} else if (!(variance.empty())) {
target_box_center_x = static_cast<T>(variance[0]) *
target_box_data[offset] * prior_box_width +
prior_box_center_x;
target_box_center_y = static_cast<T>(variance[1]) *
target_box_data[offset + 1] *
prior_box_height +
prior_box_center_y;
target_box_width = std::exp(static_cast<T>(variance[2]) *
target_box_data[offset + 2]) *
prior_box_width;
target_box_height = std::exp(static_cast<T>(variance[3]) *
target_box_data[offset + 3]) *
prior_box_height;
} else {
target_box_center_x =
target_box_data[offset] * prior_box_width + prior_box_center_x;
......@@ -175,11 +195,21 @@ class BoxCoderKernel : public framework::OpKernel<T> {
auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
auto* output_box = context.Output<framework::Tensor>("OutputBox");
std::vector<float> variance = context.Attr<std::vector<float>>("variance");
const int axis = context.Attr<int>("axis");
if (target_box->lod().size()) {
PADDLE_ENFORCE_EQ(target_box->lod().size(), 1UL,
"Only support 1 level of LoD.");
}
if (prior_box_var) {
PADDLE_ENFORCE(variance.empty(),
"Input 'PriorBoxVar' and attribute 'variance' should not"
"be used at the same time.");
}
if (!(variance.empty())) {
PADDLE_ENFORCE(static_cast<int>(variance.size()) == 4,
"Size of attribute 'variance' should be 4");
}
auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
bool normalized = context.Attr<bool>("box_normalized");
......@@ -195,10 +225,10 @@ class BoxCoderKernel : public framework::OpKernel<T> {
T* output = output_box->data<T>();
if (code_type == BoxCodeType::kEncodeCenterSize) {
EncodeCenterSize(target_box, prior_box, prior_box_var, normalized,
output);
variance, output);
} else if (code_type == BoxCodeType::kDecodeCenterSize) {
DecodeCenterSize(target_box, prior_box, prior_box_var, normalized, axis,
output);
variance, output);
}
}
};
......
......@@ -346,18 +346,104 @@ def box_coder(prior_box,
name=None,
axis=0):
"""
${comment}
**Box Coder Layer**
Encode/Decode the target bounding box with the priorbox information.
The Encoding schema described below:
.. math::
ox = (tx - px) / pw / pxv
oy = (ty - py) / ph / pyv
ow = \log(\abs(tw / pw)) / pwv
oh = \log(\abs(th / ph)) / phv
The Decoding schema described below:
.. math::
ox = (pw * pxv * tx * + px) - tw / 2
oy = (ph * pyv * ty * + py) - th / 2
ow = \exp(pwv * tw) * pw + tw / 2
oh = \exp(phv * th) * ph + th / 2
where `tx`, `ty`, `tw`, `th` denote the target box's center coordinates,
width and height respectively. Similarly, `px`, `py`, `pw`, `ph` denote
the priorbox's (anchor) center coordinates, width and height. `pxv`,
`pyv`, `pwv`, `phv` denote the variance of the priorbox and `ox`, `oy`,
`ow`, `oh` denote the encoded/decoded coordinates, width and height.
During Box Decoding, two modes for broadcast are supported. Say target
box has shape [N, M, 4], and the shape of prior box can be [N, 4] or
[M, 4]. Then prior box will broadcast to target box along the
assigned axis.
Args:
prior_box(${prior_box_type}): ${prior_box_comment}
prior_box_var(${prior_box_var_type}): ${prior_box_var_comment}
target_box(${target_box_type}): ${target_box_comment}
code_type(${code_type_type}): ${code_type_comment}
box_normalized(${box_normalized_type}): ${box_normalized_comment}
axis(${axis_type}): ${axis_comment}
prior_box(Variable): Box list prior_box is a 2-D Tensor with shape
[M, 4] holds M boxes, each box is represented as
[xmin, ymin, xmax, ymax], [xmin, ymin] is the
left top coordinate of the anchor box, if the
input is image feature map, they are close to
the origin of the coordinate system. [xmax, ymax]
is the right bottom coordinate of the anchor box.
prior_box_var(Variable|list): prior_box_var supports two types of input.
One is variable with shape [M, 4] holds M group.
The other one is list consist of 4 elements
shared by all boxes.
target_box(Variable): This input can be a 2-D LoDTensor with shape
[N, 4] when code_type is 'encode_center_size'.
This input also can be a 3-D Tensor with shape
[N, M, 4] when code_type is 'decode_center_size'.
Each box is represented as
[xmin, ymin, xmax, ymax]. This tensor can
contain LoD information to represent a batch
of inputs.
code_type(string): The code type used with the target box. It can be
encode_center_size or decode_center_size
box_normalized(int): Whether treat the priorbox as a noramlized box.
Set true by default.
name(string): The name of box coder.
axis(int): Which axis in PriorBox to broadcast for box decode,
for example, if axis is 0 and TargetBox has shape
[N, M, 4] and PriorBox has shape [M, 4], then PriorBox
will broadcast to [N, M, 4] for decoding. It is only valid
when code type is decode_center_size. Set 0 by default.
Returns:
output_box(${output_box_type}): ${output_box_comment}
output_box(Variable): When code_type is 'encode_center_size', the
output tensor of box_coder_op with shape
[N, M, 4] representing the result of N target
boxes encoded with M Prior boxes and variances.
When code_type is 'decode_center_size',
N represents the batch size and M represents
the number of deocded boxes.
Examples:
.. code-block:: python
prior_box = fluid.layers.data(name='prior_box',
shape=[512, 4],
dtype='float32',
append_batch_size=False)
target_box = fluid.layers.data(name='target_box',
shape=[512,81,4],
dtype='float32',
append_batch_size=False)
output = fluid.layers.box_coder(prior_box=prior_box,
prior_box_var=[0.1,0.1,0.2,0.2],
target_box=target_box,
code_type="decode_center_size",
box_normalized=False,
axis=1)
"""
helper = LayerHelper("box_coder", **locals())
......@@ -368,18 +454,22 @@ def box_coder(prior_box,
output_box = helper.create_variable(
name=name, dtype=prior_box.dtype, persistable=False)
helper.append_op(
type="box_coder",
inputs={
"PriorBox": prior_box,
"PriorBoxVar": prior_box_var,
"TargetBox": target_box
},
attrs={
inputs = {"PriorBox": prior_box, "TargetBox": target_box}
attrs = {
"code_type": code_type,
"box_normalized": box_normalized,
"axis": axis
},
}
if isinstance(prior_box_var, Variable):
inputs['PriorBoxVar'] = prior_box_var
elif isinstance(prior_box_var, list):
attrs['variance'] = prior_box_var
else:
raise TypeError("Input variance of box_coder must be Variable or lisz")
helper.append_op(
type="box_coder",
inputs=inputs,
attrs=attrs,
outputs={"OutputBox": output_box})
return output_box
......
......@@ -59,7 +59,7 @@ class TestDetection(unittest.TestCase):
iou = layers.iou_similarity(x=x, y=y)
bcoder = layers.box_coder(
prior_box=x,
prior_box_var=y,
prior_box_var=[0.2, 0.3, 0.3, 0.2],
target_box=z,
code_type='encode_center_size')
self.assertIsNotNone(iou)
......
......@@ -106,9 +106,9 @@ class TestBoxCoderOp(OpTest):
def setUp(self):
self.op_type = "box_coder"
lod = [[1, 1, 1, 1, 1]]
prior_box = np.random.random((10, 4)).astype('float32')
prior_box_var = np.random.random((10, 4)).astype('float32')
target_box = np.random.random((5, 10, 4)).astype('float32')
prior_box = np.random.random((81, 4)).astype('float32')
prior_box_var = np.random.random((81, 4)).astype('float32')
target_box = np.random.random((20, 81, 4)).astype('float32')
code_type = "DecodeCenterSize"
box_normalized = False
output_box = batch_box_coder(prior_box, prior_box_var, target_box,
......@@ -132,9 +132,9 @@ class TestBoxCoderOpWithOneRankVar(OpTest):
def setUp(self):
self.op_type = "box_coder"
lod = [[1, 1, 1, 1, 1]]
prior_box = np.random.random((6, 4)).astype('float32')
prior_box = np.random.random((81, 4)).astype('float32')
prior_box_var = np.random.random((4)).astype('float32')
target_box = np.random.random((3, 6, 4)).astype('float32')
target_box = np.random.random((20, 81, 4)).astype('float32')
code_type = "DecodeCenterSize"
box_normalized = False
output_box = batch_box_coder(prior_box, prior_box_var, target_box,
......@@ -159,9 +159,9 @@ class TestBoxCoderOpWithoutBoxVar(OpTest):
def setUp(self):
self.op_type = "box_coder"
lod = [[0, 1, 2, 3, 4, 5]]
prior_box = np.random.random((10, 4)).astype('float32')
prior_box_var = np.ones((10, 4)).astype('float32')
target_box = np.random.random((5, 10, 4)).astype('float32')
prior_box = np.random.random((81, 4)).astype('float32')
prior_box_var = np.ones((81, 4)).astype('float32')
target_box = np.random.random((20, 81, 4)).astype('float32')
code_type = "DecodeCenterSize"
box_normalized = False
output_box = batch_box_coder(prior_box, prior_box_var, target_box,
......@@ -184,10 +184,10 @@ class TestBoxCoderOpWithLoD(OpTest):
def setUp(self):
self.op_type = "box_coder"
lod = [[4, 8, 8]]
prior_box = np.random.random((10, 4)).astype('float32')
prior_box_var = np.random.random((10, 4)).astype('float32')
target_box = np.random.random((20, 4)).astype('float32')
lod = [[10, 20, 20]]
prior_box = np.random.random((20, 4)).astype('float32')
prior_box_var = np.random.random((20, 4)).astype('float32')
target_box = np.random.random((50, 4)).astype('float32')
code_type = "EncodeCenterSize"
box_normalized = True
output_box = batch_box_coder(prior_box, prior_box_var, target_box,
......@@ -209,9 +209,9 @@ class TestBoxCoderOpWithAxis(OpTest):
def setUp(self):
self.op_type = "box_coder"
lod = [[1, 1, 1, 1, 1]]
prior_box = np.random.random((5, 4)).astype('float32')
prior_box = np.random.random((30, 4)).astype('float32')
prior_box_var = np.random.random((4)).astype('float32')
target_box = np.random.random((5, 6, 4)).astype('float32')
target_box = np.random.random((30, 81, 4)).astype('float32')
code_type = "DecodeCenterSize"
box_normalized = False
axis = 1
......@@ -231,5 +231,34 @@ class TestBoxCoderOpWithAxis(OpTest):
self.outputs = {'OutputBox': output_box}
class TestBoxCoderOpWithVariance(OpTest):
def test_check_output(self):
self.check_output()
def setUp(self):
self.op_type = "box_coder"
lod = [[1, 1, 1, 1, 1]]
prior_box = np.random.random((30, 4)).astype('float32')
prior_box_var = np.random.random((4)).astype('float32')
target_box = np.random.random((30, 81, 4)).astype('float32')
code_type = "DecodeCenterSize"
box_normalized = False
axis = 1
output_box = batch_box_coder(prior_box, prior_box_var, target_box,
lod[0], code_type, box_normalized, axis)
self.inputs = {
'PriorBox': prior_box,
'TargetBox': target_box,
}
self.attrs = {
'code_type': 'decode_center_size',
'box_normalized': False,
'variance': prior_box_var.astype(np.float).flatten(),
'axis': axis
}
self.outputs = {'OutputBox': output_box}
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册