diff --git a/paddle/fluid/operators/detection/box_coder_op.cc b/paddle/fluid/operators/detection/box_coder_op.cc index 74848005d0bea6e5c818ff999727aa2b8ad51d84..76ef08cb9ad385681375eada7e58721022032db4 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cc +++ b/paddle/fluid/operators/detection/box_coder_op.cc @@ -91,6 +91,10 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker { "the code type used with the target box") .SetDefault("encode_center_size") .InEnum({"encode_center_size", "decode_center_size"}); + AddAttr("box_normalized", + "(bool, default true) " + "whether treat the priorbox as a noramlized box") + .SetDefault(true); AddOutput("OutputBox", "(LoDTensor or Tensor) " "When code_type is 'encode_center_size', the output tensor of " diff --git a/paddle/fluid/operators/detection/box_coder_op.cu b/paddle/fluid/operators/detection/box_coder_op.cu index 8cef8e03439df4ca5b0fa94176a21a36f9eb9f70..fc7eb5d1ed71c19630e96ea0ff0e6fe0962744a8 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cu +++ b/paddle/fluid/operators/detection/box_coder_op.cu @@ -20,15 +20,16 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data, const T* prior_box_var_data, const T* target_box_data, const int row, const int col, const int len, - T* output) { + const bool normalized, T* output) { const int idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx < row * col) { const int row_idx = idx / col; const int col_idx = idx % col; - T prior_box_width = - prior_box_data[col_idx * len + 2] - prior_box_data[col_idx * len]; - T prior_box_height = - prior_box_data[col_idx * len + 3] - prior_box_data[col_idx * len + 1]; + T prior_box_width = prior_box_data[col_idx * len + 2] - + prior_box_data[col_idx * len] + (normalized == false); + T prior_box_height = prior_box_data[col_idx * len + 3] - + prior_box_data[col_idx * len + 1] + + (normalized == false); T prior_box_center_x = (prior_box_data[col_idx * len + 2] + prior_box_data[col_idx * len]) / 2; T prior_box_center_y = (prior_box_data[col_idx * len + 3] + @@ -41,10 +42,11 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data, T target_box_center_y = (target_box_data[row_idx * len + 3] + target_box_data[row_idx * len + 1]) / 2; - T target_box_width = - target_box_data[row_idx * len + 2] - target_box_data[row_idx * len]; - T target_box_height = - target_box_data[row_idx * len + 3] - target_box_data[row_idx * len + 1]; + T target_box_width = target_box_data[row_idx * len + 2] - + target_box_data[row_idx * len] + (normalized == false); + T target_box_height = target_box_data[row_idx * len + 3] - + target_box_data[row_idx * len + 1] + + (normalized == false); output[idx * len] = (target_box_center_x - prior_box_center_x) / prior_box_width / prior_box_var_data[col_idx * len]; @@ -63,14 +65,15 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data, const T* prior_box_var_data, const T* target_box_data, const int row, const int col, const int len, - T* output) { + const bool normalized, T* output) { const int idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx < row * col) { const int col_idx = idx % col; - T prior_box_width = - prior_box_data[col_idx * len + 2] - prior_box_data[col_idx * len]; - T prior_box_height = - prior_box_data[col_idx * len + 3] - prior_box_data[col_idx * len + 1]; + T prior_box_width = prior_box_data[col_idx * len + 2] - + prior_box_data[col_idx * len] + (normalized == false); + T prior_box_height = prior_box_data[col_idx * len + 3] - + prior_box_data[col_idx * len + 1] + + (normalized == false); T prior_box_center_x = (prior_box_data[col_idx * len + 2] + prior_box_data[col_idx * len]) / 2; T prior_box_center_y = (prior_box_data[col_idx * len + 3] + @@ -93,8 +96,10 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data, output[idx * len] = target_box_center_x - target_box_width / 2; output[idx * len + 1] = target_box_center_y - target_box_height / 2; - output[idx * len + 2] = target_box_center_x + target_box_width / 2; - output[idx * len + 3] = target_box_center_y + target_box_height / 2; + output[idx * len + 2] = + target_box_center_x + target_box_width / 2 - (normalized == false); + output[idx * len + 3] = + target_box_center_y + target_box_height / 2 - (normalized == false); } } @@ -128,14 +133,15 @@ class BoxCoderCUDAKernel : public framework::OpKernel { T* output = output_box->data(); auto code_type = GetBoxCodeType(context.Attr("code_type")); + bool normalized = context.Attr("box_normalized"); if (code_type == BoxCodeType::kEncodeCenterSize) { EncodeCenterSizeKernel<<>>( prior_box_data, prior_box_var_data, target_box_data, row, col, len, - output); + normalized, output); } else if (code_type == BoxCodeType::kDecodeCenterSize) { DecodeCenterSizeKernel<<>>( prior_box_data, prior_box_var_data, target_box_data, row, col, len, - output); + normalized, output); } } }; diff --git a/paddle/fluid/operators/detection/box_coder_op.h b/paddle/fluid/operators/detection/box_coder_op.h index 77fc6c2b62af42e6526b889aeef2d9bab795baec..3dc68935ac1ea0d3e6ddf2a56bc3aba822c49230 100644 --- a/paddle/fluid/operators/detection/box_coder_op.h +++ b/paddle/fluid/operators/detection/box_coder_op.h @@ -34,7 +34,7 @@ class BoxCoderKernel : public framework::OpKernel { void EncodeCenterSize(const framework::Tensor& target_box, const framework::Tensor& prior_box, const framework::Tensor& prior_box_var, - T* output) const { + const bool normalized, T* output) const { int64_t row = target_box.dims()[0]; int64_t col = prior_box.dims()[0]; int64_t len = prior_box.dims()[1]; @@ -44,10 +44,11 @@ class BoxCoderKernel : public framework::OpKernel { for (int64_t i = 0; i < row; ++i) { for (int64_t j = 0; j < col; ++j) { - T prior_box_width = - prior_box_data[j * len + 2] - prior_box_data[j * len]; - T prior_box_height = - prior_box_data[j * len + 3] - prior_box_data[j * len + 1]; + T prior_box_width = prior_box_data[j * len + 2] - + prior_box_data[j * len] + (normalized == false); + T prior_box_height = prior_box_data[j * len + 3] - + prior_box_data[j * len + 1] + + (normalized == false); T prior_box_center_x = (prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2; T prior_box_center_y = @@ -57,10 +58,11 @@ class BoxCoderKernel : public framework::OpKernel { (target_box_data[i * len + 2] + target_box_data[i * len]) / 2; T target_box_center_y = (target_box_data[i * len + 3] + target_box_data[i * len + 1]) / 2; - T target_box_width = - target_box_data[i * len + 2] - target_box_data[i * len]; - T target_box_height = - target_box_data[i * len + 3] - target_box_data[i * len + 1]; + T target_box_width = target_box_data[i * len + 2] - + target_box_data[i * len] + (normalized == false); + T target_box_height = target_box_data[i * len + 3] - + target_box_data[i * len + 1] + + (normalized == false); size_t offset = i * col * len + j * len; output[offset] = (target_box_center_x - prior_box_center_x) / @@ -79,7 +81,7 @@ class BoxCoderKernel : public framework::OpKernel { void DecodeCenterSize(const framework::Tensor& target_box, const framework::Tensor& prior_box, const framework::Tensor& prior_box_var, - T* output) const { + const bool normalized, T* output) const { int64_t row = target_box.dims()[0]; int64_t col = prior_box.dims()[0]; int64_t len = prior_box.dims()[1]; @@ -91,10 +93,11 @@ class BoxCoderKernel : public framework::OpKernel { for (int64_t i = 0; i < row; ++i) { for (int64_t j = 0; j < col; ++j) { size_t offset = i * col * len + j * len; - T prior_box_width = - prior_box_data[j * len + 2] - prior_box_data[j * len]; - T prior_box_height = - prior_box_data[j * len + 3] - prior_box_data[j * len + 1]; + T prior_box_width = prior_box_data[j * len + 2] - + prior_box_data[j * len] + (normalized == false); + T prior_box_height = prior_box_data[j * len + 3] - + prior_box_data[j * len + 1] + + (normalized == false); T prior_box_center_x = (prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2; T prior_box_center_y = @@ -116,8 +119,10 @@ class BoxCoderKernel : public framework::OpKernel { output[offset] = target_box_center_x - target_box_width / 2; output[offset + 1] = target_box_center_y - target_box_height / 2; - output[offset + 2] = target_box_center_x + target_box_width / 2; - output[offset + 3] = target_box_center_y + target_box_height / 2; + output[offset + 2] = + target_box_center_x + target_box_width / 2 - (normalized == false); + output[offset + 3] = + target_box_center_y + target_box_height / 2 - (normalized == false); } } } @@ -139,11 +144,14 @@ class BoxCoderKernel : public framework::OpKernel { output_box->mutable_data({row, col, len}, context.GetPlace()); auto code_type = GetBoxCodeType(context.Attr("code_type")); + bool normalized = context.Attr("box_normalized"); T* output = output_box->data(); if (code_type == BoxCodeType::kEncodeCenterSize) { - EncodeCenterSize(*target_box, *prior_box, *prior_box_var, output); + EncodeCenterSize(*target_box, *prior_box, *prior_box_var, normalized, + output); } else if (code_type == BoxCodeType::kDecodeCenterSize) { - DecodeCenterSize(*target_box, *prior_box, *prior_box_var, output); + DecodeCenterSize(*target_box, *prior_box, *prior_box_var, normalized, + output); } } }; diff --git a/python/paddle/fluid/tests/unittests/test_box_coder_op.py b/python/paddle/fluid/tests/unittests/test_box_coder_op.py index 56f5af91d8e58086c12fde6948229675569aa271..a31b7ea322ff0a351308bea5608b2af9b60ac582 100644 --- a/python/paddle/fluid/tests/unittests/test_box_coder_op.py +++ b/python/paddle/fluid/tests/unittests/test_box_coder_op.py @@ -19,7 +19,8 @@ import math from op_test import OpTest -def box_coder(target_box, prior_box, prior_box_var, output_box, code_type): +def box_coder(target_box, prior_box, prior_box_var, output_box, code_type, + box_normalized): prior_box_x = ( (prior_box[:, 2] + prior_box[:, 0]) / 2).reshape(1, prior_box.shape[0]) prior_box_y = ( @@ -30,6 +31,9 @@ def box_coder(target_box, prior_box, prior_box_var, output_box, code_type): (prior_box[:, 3] - prior_box[:, 1])).reshape(1, prior_box.shape[0]) prior_box_var = prior_box_var.reshape(1, prior_box_var.shape[0], prior_box_var.shape[1]) + if not box_normalized: + prior_box_height = prior_box_height + 1 + prior_box_width = prior_box_width + 1 if (code_type == "EncodeCenterSize"): target_box_x = ((target_box[:, 2] + target_box[:, 0]) / 2).reshape( @@ -40,6 +44,9 @@ def box_coder(target_box, prior_box, prior_box_var, output_box, code_type): target_box.shape[0], 1) target_box_height = ((target_box[:, 3] - target_box[:, 1])).reshape( target_box.shape[0], 1) + if not box_normalized: + target_box_height = target_box_height + 1 + target_box_width = target_box_width + 1 output_box[:,:,0] = (target_box_x - prior_box_x) / prior_box_width / \ prior_box_var[:,:,0] @@ -64,9 +71,13 @@ def box_coder(target_box, prior_box, prior_box_var, output_box, code_type): output_box[:, :, 1] = target_box_y - target_box_height / 2 output_box[:, :, 2] = target_box_x + target_box_width / 2 output_box[:, :, 3] = target_box_y + target_box_height / 2 + if not box_normalized: + output_box[:, :, 2] = output_box[:, :, 2] - 1 + output_box[:, :, 3] = output_box[:, :, 3] - 1 -def batch_box_coder(prior_box, prior_box_var, target_box, lod, code_type): +def batch_box_coder(prior_box, prior_box_var, target_box, lod, code_type, + box_normalized): n = target_box.shape[0] m = prior_box.shape[0] output_box = np.zeros((n, m, 4), dtype=np.float32) @@ -74,11 +85,11 @@ def batch_box_coder(prior_box, prior_box_var, target_box, lod, code_type): if (code_type == "EncodeCenterSize"): box_coder(target_box[lod[i]:lod[i + 1], :], prior_box, prior_box_var, output_box[lod[i]:lod[i + 1], :, :], - code_type) + code_type, box_normalized) elif (code_type == "DecodeCenterSize"): box_coder(target_box[lod[i]:lod[i + 1], :, :], prior_box, prior_box_var, output_box[lod[i]:lod[i + 1], :, :], - code_type) + code_type, box_normalized) return output_box @@ -93,15 +104,19 @@ class TestBoxCoderOp(OpTest): prior_box_var = np.random.random((10, 4)).astype('float32') target_box = np.random.random((5, 10, 4)).astype('float32') code_type = "DecodeCenterSize" + box_normalized = False output_box = batch_box_coder(prior_box, prior_box_var, target_box, - lod[0], code_type) + lod[0], code_type, box_normalized) self.inputs = { 'PriorBox': prior_box, 'PriorBoxVar': prior_box_var, 'TargetBox': target_box, } - self.attrs = {'code_type': 'decode_center_size'} + self.attrs = { + 'code_type': 'decode_center_size', + 'box_normalized': False + } self.outputs = {'OutputBox': output_box} @@ -116,15 +131,16 @@ class TestBoxCoderOpWithLoD(OpTest): prior_box_var = np.random.random((10, 4)).astype('float32') target_box = np.random.random((20, 4)).astype('float32') code_type = "EncodeCenterSize" + box_normalized = True output_box = batch_box_coder(prior_box, prior_box_var, target_box, - lod[0], code_type) + lod[0], code_type, box_normalized) self.inputs = { 'PriorBox': prior_box, 'PriorBoxVar': prior_box_var, 'TargetBox': (target_box, lod), } - self.attrs = {'code_type': 'encode_center_size'} + self.attrs = {'code_type': 'encode_center_size', 'box_normalized': True} self.outputs = {'OutputBox': output_box}