From e84615bae605f750b8be9694a5967333d3e190af Mon Sep 17 00:00:00 2001 From: Yuan Gao Date: Wed, 28 Feb 2018 18:57:46 +0800 Subject: [PATCH] Fix box coder op (#8647) * fix ssd problems * fix box decoder op * fix dimension problem in detection tests * update detection doc * Update detection doc * Update detection doc * update detection doc * update detection doc --- paddle/fluid/operators/box_coder_op.cc | 48 ++++++++++------- paddle/fluid/operators/box_coder_op.cu | 10 ++-- paddle/fluid/operators/box_coder_op.h | 10 ++-- python/paddle/fluid/layers/detection.py | 51 ++++++++++++------- python/paddle/fluid/tests/test_detection.py | 8 ++- .../tests/unittests/test_box_coder_op.py | 17 ++++--- 6 files changed, 85 insertions(+), 59 deletions(-) diff --git a/paddle/fluid/operators/box_coder_op.cc b/paddle/fluid/operators/box_coder_op.cc index 1fc201286..eccdd408a 100644 --- a/paddle/fluid/operators/box_coder_op.cc +++ b/paddle/fluid/operators/box_coder_op.cc @@ -37,12 +37,19 @@ class BoxCoderOp : public framework::OperatorWithKernel { "The rank of Input of PriorBoxVar must be 2"); PADDLE_ENFORCE_EQ(prior_box_dims[1], 4, "The shape of PriorBox is [N, 4]"); PADDLE_ENFORCE_EQ(prior_box_dims, prior_box_var_dims); - PADDLE_ENFORCE_EQ(target_box_dims.size(), 2, - "The rank of Input of TargetBox must be 2"); - PADDLE_ENFORCE_EQ(target_box_dims[1], 4, - "The shape of TargetBox is [M, 4]"); - GetBoxCodeType(ctx->Attrs().Get("code_type")); + auto code_type = GetBoxCodeType(ctx->Attrs().Get("code_type")); + if (code_type == BoxCodeType::kEncodeCenterSize) { + PADDLE_ENFORCE_EQ(target_box_dims.size(), 2, + "The rank of Input of TargetBox must be 2"); + PADDLE_ENFORCE_EQ(target_box_dims[1], 4, + "The shape of TargetBox is [M, 4]"); + } else if (code_type == BoxCodeType::kDecodeCenterSize) { + PADDLE_ENFORCE_EQ(target_box_dims.size(), 3, + "The rank of Input of TargetBox must be 3"); + PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]); + PADDLE_ENFORCE_EQ(target_box_dims[2], prior_box_dims[1]); + } ctx->SetOutputDim( "OutputBox", @@ -70,25 +77,28 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker { "of variance."); AddInput( "TargetBox", - "(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape " - "[N, 4], each box is represented as [xmin, ymin, xmax, ymax], " - "[xmin, ymin] is the left top coordinate of the box if the input " - "is image feature map, they are close to the origin of the coordinate " - "system. [xmax, ymax] is the right bottom coordinate of the box. " - "This tensor can contain LoD information to represent a batch " - "of inputs. One instance of this batch can contain different " - "numbers of entities."); + "(LoDTensor or Tensor) This input can be a 2-D LoDTensor with shape " + "[N, 4] when code_type is 'encode_center_size'. This input also can " + "be a 3-D Tensor with shape [N, M, 4] when code_type is " + "'decode_center_size'. [N, 4], each box is represented as " + "[xmin, ymin, xmax, ymax], [xmin, ymin] is the left top coordinate " + "of the box if the input is image feature map, they are close to " + "the origin of the coordinate system. [xmax, ymax] is the right " + "bottom coordinate of the box. This tensor can contain LoD " + "information to represent a batch of inputs. One instance of this " + "batch can contain different numbers of entities."); AddAttr("code_type", "(string, default encode_center_size) " "the code type used with the target box") .SetDefault("encode_center_size") .InEnum({"encode_center_size", "decode_center_size"}); - AddOutput( - "OutputBox", - "(LoDTensor or Tensor) " - "(Tensor) The output of box_coder_op, a tensor with shape [N, M, 4] " - "representing the result of N target boxes encoded/decoded with " - "M Prior boxes and variances."); + AddOutput("OutputBox", + "(LoDTensor or Tensor) " + "When code_type is 'encode_center_size', the output tensor of " + "box_coder_op with shape [N, M, 4] representing the result of N " + "target boxes encoded with M Prior boxes and variances. When " + "code_type is 'decode_center_size', N represents the batch size " + "and M represents the number of deocded boxes."); AddComment(R"DOC( Bounding Box Coder Operator. diff --git a/paddle/fluid/operators/box_coder_op.cu b/paddle/fluid/operators/box_coder_op.cu index 7ab242edf..0944e9c95 100644 --- a/paddle/fluid/operators/box_coder_op.cu +++ b/paddle/fluid/operators/box_coder_op.cu @@ -66,7 +66,6 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data, T* output) { const int idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx < row * col) { - const int row_idx = idx / col; const int col_idx = idx % col; T prior_box_width = prior_box_data[col_idx * len + 2] - prior_box_data[col_idx * len]; @@ -79,17 +78,16 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data, 2; T target_box_width = exp(prior_box_var_data[col_idx * len + 2] * - target_box_data[row_idx * len + 2]) * + target_box_data[idx * len + 2]) * prior_box_width; T target_box_height = exp(prior_box_var_data[col_idx * len + 3] * - target_box_data[row_idx * len + 3]) * + target_box_data[idx * len + 3]) * prior_box_height; T target_box_center_x = prior_box_var_data[col_idx * len] * - target_box_data[row_idx * len] * - prior_box_width + + target_box_data[idx * len] * prior_box_width + prior_box_center_x; T target_box_center_y = prior_box_var_data[col_idx * len + 1] * - target_box_data[row_idx * len + 1] * + target_box_data[idx * len + 1] * prior_box_height + prior_box_center_y; diff --git a/paddle/fluid/operators/box_coder_op.h b/paddle/fluid/operators/box_coder_op.h index 5e105aff5..3c7cac1cd 100644 --- a/paddle/fluid/operators/box_coder_op.h +++ b/paddle/fluid/operators/box_coder_op.h @@ -89,6 +89,7 @@ class BoxCoderKernel : public framework::OpKernel { for (int64_t i = 0; i < row; ++i) { for (int64_t j = 0; j < col; ++j) { + size_t offset = i * col * len + j * len; T prior_box_width = prior_box_data[j * len + 2] - prior_box_data[j * len]; T prior_box_height = @@ -99,20 +100,19 @@ class BoxCoderKernel : public framework::OpKernel { (prior_box_data[j * len + 3] + prior_box_data[j * len + 1]) / 2; T target_box_center_x = prior_box_var_data[j * len] * - target_box_data[i * len] * prior_box_width + + target_box_data[offset] * prior_box_width + prior_box_center_x; T target_box_center_y = prior_box_var_data[j * len + 1] * - target_box_data[i * len + 1] * + target_box_data[offset + 1] * prior_box_height + prior_box_center_y; T target_box_width = std::exp(prior_box_var_data[j * len + 2] * - target_box_data[i * len + 2]) * + target_box_data[offset + 2]) * prior_box_width; T target_box_height = std::exp(prior_box_var_data[j * len + 3] * - target_box_data[i * len + 3]) * + target_box_data[offset + 3]) * prior_box_height; - size_t offset = i * col * len + j * len; output[offset] = target_box_center_x - target_box_width / 2; output[offset + 1] = target_box_center_y - target_box_height / 2; output[offset + 2] = target_box_center_x + target_box_width / 2; diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index f380f5c00..a077c0ce3 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -43,8 +43,8 @@ for _OP in set(__auto__): globals()[_OP] = generate_layer_fn(_OP) -def detection_output(scores, - loc, +def detection_output(loc, + scores, prior_box, prior_box_var, background_label=0, @@ -61,14 +61,14 @@ def detection_output(scores, be zero if there is no valid bounding box. Args: - scores(Variable): A 3-D Tensor with shape [N, C, M] represents the - predicted confidence predictions. N is the batch size, C is the - class number, M is number of bounding boxes. For each category - there are total M scores which corresponding M bounding boxes. loc(Variable): A 3-D Tensor with shape [N, M, 4] represents the predicted locations of M bounding bboxes. N is the batch size, and each bounding box has four coordinate values and the layout is [xmin, ymin, xmax, ymax]. + scores(Variable): A 3-D Tensor with shape [N, M, C] represents the + predicted confidence predictions. N is the batch size, C is the + class number, M is number of bounding boxes. For each category + there are total M scores which corresponding M bounding boxes. prior_box(Variable): A 2-D Tensor with shape [M, 4] holds M boxes, each box is represented as [xmin, ymin, xmax, ymax], [xmin, ymin] is the left top coordinate of the anchor box, @@ -100,7 +100,7 @@ def detection_output(scores, append_batch_size=False, dtype='float32') pbv = layers.data(name='prior_box_var', shape=[10, 4], append_batch_size=False, dtype='float32') - loc = layers.data(name='target_box', shape=[21, 4], + loc = layers.data(name='target_box', shape=[2, 21, 4], append_batch_size=False, dtype='float32') scores = layers.data(name='scores', shape=[2, 21, 10], append_batch_size=False, dtype='float32') @@ -109,7 +109,6 @@ def detection_output(scores, prior_box=pb, prior_box_var=pbv) """ - helper = LayerHelper("detection_output", **locals()) decoded_box = box_coder( prior_box=prior_box, @@ -118,6 +117,7 @@ def detection_output(scores, code_type='decode_center_size') nmsed_outs = helper.create_tmp_variable(dtype=decoded_box.dtype) + scores = nn.transpose(scores, perm=[0, 2, 1]) helper.append_op( type="multiclass_nms", inputs={'Scores': scores, @@ -595,12 +595,13 @@ def multi_box_head(inputs, name(str): Name of the prior box layer. Default: None. Returns: - mbox_loc(list): The predicted boxes' location of the inputs. - The layout of each element is [N, H, W, Priors]. Priors - is the number of predicted boxof each position of each input. - mbox_conf(list): The predicted boxes' confidence of the inputs. - The layout of each element is [N, H, W, Priors]. Priors - is the number of predicted box of each position of each input. + mbox_loc(Variable): The predicted boxes' location of the inputs. + The layout is [N, H*W*Priors, 4]. where Priors + is the number of predicted boxes each position of each input. + mbox_conf(Variable): The predicted boxes' confidence of the inputs. + The layout is [N, H*W*Priors, C]. where Priors + is the number of predicted boxes each position of each input + and C is the number of Classes. boxes(Variable): the output prior boxes of PriorBox. The layout is [num_priors, 4]. num_priors is the total box count of each position of inputs. @@ -751,7 +752,7 @@ def multi_box_head(inputs, num_boxes = box.shape[2] # get box_loc - num_loc_output = num_boxes * num_classes * 4 + num_loc_output = num_boxes * 4 mbox_loc = nn.conv2d( input=input, num_filters=num_loc_output, @@ -760,7 +761,12 @@ def multi_box_head(inputs, stride=stride) mbox_loc = nn.transpose(mbox_loc, perm=[0, 2, 3, 1]) - mbox_locs.append(mbox_loc) + new_shape = [ + mbox_loc.shape[0], + mbox_loc.shape[1] * mbox_loc.shape[2] * mbox_loc.shape[3] / 4, 4 + ] + mbox_loc_flatten = ops.reshape(mbox_loc, shape=new_shape) + mbox_locs.append(mbox_loc_flatten) # get conf_loc num_conf_output = num_boxes * num_classes @@ -771,11 +777,18 @@ def multi_box_head(inputs, padding=pad, stride=stride) conf_loc = nn.transpose(conf_loc, perm=[0, 2, 3, 1]) - mbox_confs.append(conf_loc) + new_shape = [ + conf_loc.shape[0], conf_loc.shape[1] * conf_loc.shape[2] * + conf_loc.shape[3] / num_classes, num_classes + ] + conf_loc_flatten = ops.reshape(conf_loc, shape=new_shape) + mbox_confs.append(conf_loc_flatten) if len(box_results) == 1: box = box_results[0] var = var_results[0] + mbox_locs_concat = mbox_locs[0] + mbox_confs_concat = mbox_confs[0] else: reshaped_boxes = [] reshaped_vars = [] @@ -785,5 +798,7 @@ def multi_box_head(inputs, box = tensor.concat(reshaped_boxes) var = tensor.concat(reshaped_vars) + mbox_locs_concat = tensor.concat(mbox_locs, axis=1) + mbox_confs_concat = tensor.concat(mbox_confs, axis=1) - return mbox_locs, mbox_confs, box, var + return mbox_locs_concat, mbox_confs_concat, box, var diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index fc2578649..0d2d653c0 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -35,12 +35,12 @@ class TestDetection(unittest.TestCase): dtype='float32') loc = layers.data( name='target_box', - shape=[20, 4], + shape=[2, 10, 4], append_batch_size=False, dtype='float32') scores = layers.data( name='scores', - shape=[2, 20, 10], + shape=[2, 10, 20], append_batch_size=False, dtype='float32') out = layers.detection_output( @@ -117,9 +117,7 @@ class TestMultiBoxHead(unittest.TestCase): assert len(box.shape) == 2 assert box.shape == var.shape assert box.shape[1] == 4 - - for loc, conf in zip(mbox_locs, mbox_confs): - assert loc.shape[1:3] == conf.shape[1:3] + assert mbox_locs.shape[1] == mbox_confs.shape[1] def multi_box_head_output(self, data_shape): images = fluid.layers.data( diff --git a/python/paddle/fluid/tests/unittests/test_box_coder_op.py b/python/paddle/fluid/tests/unittests/test_box_coder_op.py index b83917609..56f5af91d 100644 --- a/python/paddle/fluid/tests/unittests/test_box_coder_op.py +++ b/python/paddle/fluid/tests/unittests/test_box_coder_op.py @@ -51,8 +51,6 @@ def box_coder(target_box, prior_box, prior_box_var, output_box, code_type): prior_box_var[:,:,3] elif (code_type == "DecodeCenterSize"): - target_box = target_box.reshape(target_box.shape[0], 1, - target_box.shape[1]) target_box_x = prior_box_var[:,:,0] * target_box[:,:,0] * \ prior_box_width + prior_box_x target_box_y = prior_box_var[:,:,1] * target_box[:,:,1] * \ @@ -61,6 +59,7 @@ def box_coder(target_box, prior_box, prior_box_var, output_box, code_type): prior_box_width target_box_height = np.exp(prior_box_var[:,:,3] * target_box[:,:,3]) * \ prior_box_height + output_box[:, :, 0] = target_box_x - target_box_width / 2 output_box[:, :, 1] = target_box_y - target_box_height / 2 output_box[:, :, 2] = target_box_x + target_box_width / 2 @@ -72,8 +71,14 @@ def batch_box_coder(prior_box, prior_box_var, target_box, lod, code_type): m = prior_box.shape[0] output_box = np.zeros((n, m, 4), dtype=np.float32) for i in range(len(lod) - 1): - box_coder(target_box[lod[i]:lod[i + 1], :], prior_box, prior_box_var, - output_box[lod[i]:lod[i + 1], :, :], code_type) + if (code_type == "EncodeCenterSize"): + box_coder(target_box[lod[i]:lod[i + 1], :], prior_box, + prior_box_var, output_box[lod[i]:lod[i + 1], :, :], + code_type) + elif (code_type == "DecodeCenterSize"): + box_coder(target_box[lod[i]:lod[i + 1], :, :], prior_box, + prior_box_var, output_box[lod[i]:lod[i + 1], :, :], + code_type) return output_box @@ -83,10 +88,10 @@ class TestBoxCoderOp(OpTest): def setUp(self): self.op_type = "box_coder" - lod = [[0, 20]] + lod = [[0, 1, 2, 3, 4, 5]] prior_box = np.random.random((10, 4)).astype('float32') prior_box_var = np.random.random((10, 4)).astype('float32') - target_box = np.random.random((20, 4)).astype('float32') + target_box = np.random.random((5, 10, 4)).astype('float32') code_type = "DecodeCenterSize" output_box = batch_box_coder(prior_box, prior_box_var, target_box, lod[0], code_type) -- GitLab