diff --git a/paddle/fluid/operators/detection/box_coder_op.cc b/paddle/fluid/operators/detection/box_coder_op.cc index fdcff62e1fe59b3a2f4925bdff98632f71220abb..0a51d50e06176e713922837861f2102c9ee8a899 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cc +++ b/paddle/fluid/operators/detection/box_coder_op.cc @@ -38,20 +38,12 @@ class BoxCoderOp : public framework::OperatorWithKernel { "The shape of PriorBox is [N, 4]"); if (ctx->HasInput("PriorBoxVar")) { auto prior_box_var_dims = ctx->GetInputDim("PriorBoxVar"); - PADDLE_ENFORCE( - prior_box_var_dims.size() == 1 || prior_box_var_dims.size() == 2, - "Input(PriorBoxVar) of BoxCoderOp should be 1 or 2."); - if (prior_box_var_dims.size() == 1) { - PADDLE_ENFORCE_EQ( - prior_box_var_dims[0], 4, - "The 1st dimension of Input(PriorBoxVar) should be 4" - "when the rank is 1."); - } else { - PADDLE_ENFORCE_EQ( - prior_box_dims, prior_box_var_dims, - "The dimension of Input(PriorBoxVar) should be equal to" - "the dimension of Input(PriorBox when the rank is 2.)"); - } + PADDLE_ENFORCE(prior_box_var_dims.size() == 2, + "Input(PriorBoxVar) of BoxCoderOp should be 2."); + PADDLE_ENFORCE_EQ( + prior_box_dims, prior_box_var_dims, + "The dimension of Input(PriorBoxVar) should be equal to" + "the dimension of Input(PriorBox) when the rank is 2."); } } diff --git a/paddle/fluid/operators/detection/box_coder_op.cu b/paddle/fluid/operators/detection/box_coder_op.cu index e078af3eb478a8bebc6a7fc6460d169d803a3c4b..19a5bb90fa828899ad6270c051090dd3662aeed8 100644 --- a/paddle/fluid/operators/detection/box_coder_op.cu +++ b/paddle/fluid/operators/detection/box_coder_op.cu @@ -56,10 +56,7 @@ __global__ void EncodeCenterSizeKernel( output[idx * len + 2] = log(fabs(target_box_width / prior_box_width)); output[idx * len + 3] = log(fabs(target_box_height / prior_box_height)); if (prior_box_var_data) { - int prior_var_offset = 0; - if (prior_box_var_size == 2) { - prior_var_offset = col_idx * len; - } + int prior_var_offset = col_idx * len; output[idx * len] /= prior_box_var_data[prior_var_offset]; output[idx * len + 1] /= prior_box_var_data[prior_var_offset + 1]; output[idx * len + 2] /= prior_box_var_data[prior_var_offset + 2]; @@ -99,10 +96,7 @@ __global__ void DecodeCenterSizeKernel( T box_var_x = T(1), box_var_y = T(1); T box_var_w = T(1), box_var_h = T(1); if (prior_box_var_data) { - int prior_var_offset = 0; - if (prior_box_var_size == 2) { - prior_var_offset = axis == 0 ? col_idx * len : row_idx * len; - } + int prior_var_offset = axis == 0 ? col_idx * len : row_idx * len; box_var_x = prior_box_var_data[prior_var_offset]; box_var_y = prior_box_var_data[prior_var_offset + 1]; box_var_w = prior_box_var_data[prior_var_offset + 2]; diff --git a/paddle/fluid/operators/detection/box_coder_op.h b/paddle/fluid/operators/detection/box_coder_op.h index a0b1faf7bdc7001eba2d92b4d03fbaf9feb7bcbb..6d406f8196f9964c85bb94541fa7a7a23857539b 100644 --- a/paddle/fluid/operators/detection/box_coder_op.h +++ b/paddle/fluid/operators/detection/box_coder_op.h @@ -79,10 +79,7 @@ class BoxCoderKernel : public framework::OpKernel { output[offset + 3] = std::log(std::fabs(target_box_height / prior_box_height)); if (prior_box_var) { - int prior_var_offset = 0; - if (prior_box_var->dims().size() == 2) { - prior_var_offset = j * len; - } + int prior_var_offset = j * len; output[offset] /= prior_box_var_data[prior_var_offset]; output[offset + 1] /= prior_box_var_data[prior_var_offset + 1]; output[offset + 2] /= prior_box_var_data[prior_var_offset + 2]; @@ -95,11 +92,12 @@ class BoxCoderKernel : public framework::OpKernel { } } } + template void DecodeCenterSize(const framework::Tensor* target_box, const framework::Tensor* prior_box, const framework::Tensor* prior_box_var, - const bool normalized, const int axis, - const std::vector variance, T* output) const { + const bool normalized, std::vector variance, + T* output) const { int64_t row = target_box->dims()[0]; int64_t col = target_box->dims()[1]; int64_t len = target_box->dims()[2]; @@ -107,19 +105,17 @@ class BoxCoderKernel : public framework::OpKernel { auto* target_box_data = target_box->data(); auto* prior_box_data = prior_box->data(); const T* prior_box_var_data = nullptr; - if (prior_box_var) prior_box_var_data = prior_box_var->data(); + if (var_size == 2) prior_box_var_data = prior_box_var->data(); int prior_box_offset = 0; + T var_data[4] = {1., 1., 1., 1.}; + T* var_ptr = var_data; #ifdef PADDLE_WITH_MKLML #pragma omp parallel for collapse(2) #endif for (int64_t i = 0; i < row; ++i) { for (int64_t j = 0; j < col; ++j) { size_t offset = i * col * len + j * len; - if (axis == 0) { - prior_box_offset = j * len; - } else if (axis == 1) { - prior_box_offset = i * len; - } + prior_box_offset = axis == 0 ? j * len : i * len; T prior_box_width = prior_box_data[prior_box_offset + 2] - prior_box_data[prior_box_offset] + (normalized == false); @@ -133,26 +129,18 @@ class BoxCoderKernel : public framework::OpKernel { T target_box_center_x = 0, target_box_center_y = 0; T target_box_width = 0, target_box_height = 0; - T box_var_x = T(1), box_var_y = T(1); - T box_var_w = T(1), box_var_h = T(1); - if (prior_box_var) { - int prior_var_offset = 0; - if (prior_box_var->dims().size() == 2) { - if (axis == 0) - prior_var_offset = j * len; - else if (axis == 1) - prior_var_offset = i * len; - } - box_var_x = prior_box_var_data[prior_var_offset]; - box_var_y = prior_box_var_data[prior_var_offset + 1]; - box_var_w = prior_box_var_data[prior_var_offset + 2]; - box_var_h = prior_box_var_data[prior_var_offset + 3]; - } else if (!(variance.empty())) { - box_var_x = static_cast(variance[0]); - box_var_y = static_cast(variance[1]); - box_var_w = static_cast(variance[2]); - box_var_h = static_cast(variance[3]); + int prior_var_offset = axis == 0 ? j * len : i * len; + if (var_size == 2) { + std::memcpy(var_ptr, prior_box_var_data + prior_var_offset, + 4 * sizeof(T)); + } else if (var_size == 1) { + var_ptr = reinterpret_cast(variance.data()); } + T box_var_x = *var_ptr; + T box_var_y = *(var_ptr + 1); + T box_var_w = *(var_ptr + 2); + T box_var_h = *(var_ptr + 3); + target_box_center_x = box_var_x * target_box_data[offset] * prior_box_width + prior_box_center_x; @@ -211,8 +199,31 @@ class BoxCoderKernel : public framework::OpKernel { EncodeCenterSize(target_box, prior_box, prior_box_var, normalized, variance, output); } else if (code_type == BoxCodeType::kDecodeCenterSize) { - DecodeCenterSize(target_box, prior_box, prior_box_var, normalized, axis, - variance, output); + if (prior_box_var) { + if (axis == 0) { + DecodeCenterSize<0, 2>(target_box, prior_box, prior_box_var, + normalized, variance, output); + } else { + DecodeCenterSize<1, 2>(target_box, prior_box, prior_box_var, + normalized, variance, output); + } + } else if (!(variance.empty())) { + if (axis == 0) { + DecodeCenterSize<0, 1>(target_box, prior_box, prior_box_var, + normalized, variance, output); + } else { + DecodeCenterSize<1, 1>(target_box, prior_box, prior_box_var, + normalized, variance, output); + } + } else { + if (axis == 0) { + DecodeCenterSize<0, 0>(target_box, prior_box, prior_box_var, + normalized, variance, output); + } else { + DecodeCenterSize<1, 0>(target_box, prior_box, prior_box_var, + normalized, variance, output); + } + } } } }; diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index c983e2a44b25c5943df5e822e2e363b2557a6ac3..3b43ae0b9cb63a9f4708a680cb1021d74c197550 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -397,10 +397,10 @@ def box_coder(prior_box, input is image feature map, they are close to the origin of the coordinate system. [xmax, ymax] is the right bottom coordinate of the anchor box. - prior_box_var(Variable|list): prior_box_var supports two types of input. - One is variable with shape [M, 4] holds M group. - The other one is list consist of 4 elements - shared by all boxes. + prior_box_var(Variable|list|None): prior_box_var supports two types + of input. One is variable with shape [M, 4] + holds M group. The other one is list consist of + 4 elements shared by all boxes. target_box(Variable): This input can be a 2-D LoDTensor with shape [N, 4] when code_type is 'encode_center_size'. This input also can be a 3-D Tensor with shape diff --git a/python/paddle/fluid/tests/unittests/test_box_coder_op.py b/python/paddle/fluid/tests/unittests/test_box_coder_op.py index 6156268bf25ada310a3d22242ecff4b9cdf1759a..220bffebe83925c60af65aa9594ddd8a29c38145 100644 --- a/python/paddle/fluid/tests/unittests/test_box_coder_op.py +++ b/python/paddle/fluid/tests/unittests/test_box_coder_op.py @@ -34,7 +34,9 @@ def box_decoder(t_box, p_box, pb_v, output_box, norm, axis=0): pb_y = pb_y.reshape(shape) if pb_v.ndim == 2: - pb_v = pb_v.reshape(1, pb_v.shape[0], pb_v.shape[1]) + var_shape = (1, pb_v.shape[0], pb_v.shape[1]) if axis == 0 else ( + pb_v.shape[0], 1, pb_v.shape[1]) + pb_v = pb_v.reshape(var_shape) if pb_v.ndim == 1: tb_x = pb_v[0] * t_box[:, :, 0] * pb_w + pb_x tb_y = pb_v[1] * t_box[:, :, 1] * pb_h + pb_y @@ -125,33 +127,6 @@ class TestBoxCoderOp(OpTest): self.outputs = {'OutputBox': output_box} -class TestBoxCoderOpWithOneRankVar(OpTest): - def test_check_output(self): - self.check_output() - - def setUp(self): - self.op_type = "box_coder" - lod = [[1, 1, 1, 1, 1]] - prior_box = np.random.random((81, 4)).astype('float32') - prior_box_var = np.random.random((4)).astype('float32') - target_box = np.random.random((20, 81, 4)).astype('float32') - code_type = "DecodeCenterSize" - box_normalized = False - output_box = batch_box_coder(prior_box, prior_box_var, target_box, - lod[0], code_type, box_normalized) - - self.inputs = { - 'PriorBox': prior_box, - 'PriorBoxVar': prior_box_var, - 'TargetBox': target_box, - } - self.attrs = { - 'code_type': 'decode_center_size', - 'box_normalized': False - } - self.outputs = {'OutputBox': output_box} - - class TestBoxCoderOpWithoutBoxVar(OpTest): def test_check_output(self): self.check_output() @@ -210,7 +185,7 @@ class TestBoxCoderOpWithAxis(OpTest): self.op_type = "box_coder" lod = [[1, 1, 1, 1, 1]] prior_box = np.random.random((30, 4)).astype('float32') - prior_box_var = np.random.random((4)).astype('float32') + prior_box_var = np.random.random((30, 4)).astype('float32') target_box = np.random.random((30, 81, 4)).astype('float32') code_type = "DecodeCenterSize" box_normalized = False