Commit e84615ba authored by Yuan Gao, committed by qingqing01

Fix box coder op (#8647)

* fix ssd problems

* fix box decoder op

* fix dimension problem in detection tests

* update detection doc

* Update detection doc

* Update detection doc

* update detection doc

* update detection doc
Parent 9344e4e3
......@@ -37,12 +37,19 @@ class BoxCoderOp : public framework::OperatorWithKernel {
"The rank of Input of PriorBoxVar must be 2");
PADDLE_ENFORCE_EQ(prior_box_dims[1], 4, "The shape of PriorBox is [N, 4]");
PADDLE_ENFORCE_EQ(prior_box_dims, prior_box_var_dims);
PADDLE_ENFORCE_EQ(target_box_dims.size(), 2,
"The rank of Input of TargetBox must be 2");
PADDLE_ENFORCE_EQ(target_box_dims[1], 4,
"The shape of TargetBox is [M, 4]");
GetBoxCodeType(ctx->Attrs().Get<std::string>("code_type"));
auto code_type = GetBoxCodeType(ctx->Attrs().Get<std::string>("code_type"));
if (code_type == BoxCodeType::kEncodeCenterSize) {
PADDLE_ENFORCE_EQ(target_box_dims.size(), 2,
"The rank of Input of TargetBox must be 2");
PADDLE_ENFORCE_EQ(target_box_dims[1], 4,
"The shape of TargetBox is [M, 4]");
} else if (code_type == BoxCodeType::kDecodeCenterSize) {
PADDLE_ENFORCE_EQ(target_box_dims.size(), 3,
"The rank of Input of TargetBox must be 3");
PADDLE_ENFORCE_EQ(target_box_dims[1], prior_box_dims[0]);
PADDLE_ENFORCE_EQ(target_box_dims[2], prior_box_dims[1]);
}
ctx->SetOutputDim(
"OutputBox",
......@@ -70,25 +77,28 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
"of variance.");
AddInput(
"TargetBox",
"(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape "
"[N, 4], each box is represented as [xmin, ymin, xmax, ymax], "
"[xmin, ymin] is the left top coordinate of the box if the input "
"is image feature map, they are close to the origin of the coordinate "
"system. [xmax, ymax] is the right bottom coordinate of the box. "
"This tensor can contain LoD information to represent a batch "
"of inputs. One instance of this batch can contain different "
"numbers of entities.");
"(LoDTensor or Tensor) This input can be a 2-D LoDTensor with shape "
"[N, 4] when code_type is 'encode_center_size'. This input also can "
"be a 3-D Tensor with shape [N, M, 4] when code_type is "
"'decode_center_size'. [N, 4], each box is represented as "
"[xmin, ymin, xmax, ymax], [xmin, ymin] is the left top coordinate "
"of the box if the input is image feature map, they are close to "
"the origin of the coordinate system. [xmax, ymax] is the right "
"bottom coordinate of the box. This tensor can contain LoD "
"information to represent a batch of inputs. One instance of this "
"batch can contain different numbers of entities.");
AddAttr<std::string>("code_type",
"(string, default encode_center_size) "
"the code type used with the target box")
.SetDefault("encode_center_size")
.InEnum({"encode_center_size", "decode_center_size"});
AddOutput(
"OutputBox",
"(LoDTensor or Tensor) "
"(Tensor) The output of box_coder_op, a tensor with shape [N, M, 4] "
"representing the result of N target boxes encoded/decoded with "
"M Prior boxes and variances.");
AddOutput("OutputBox",
"(LoDTensor or Tensor) "
"When code_type is 'encode_center_size', the output tensor of "
"box_coder_op with shape [N, M, 4] representing the result of N "
"target boxes encoded with M Prior boxes and variances. When "
"code_type is 'decode_center_size', N represents the batch size "
"and M represents the number of deocded boxes.");
AddComment(R"DOC(
Bounding Box Coder Operator.
......
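For reference, a minimal NumPy sketch (not part of this commit) of the shape contract the new InferShape branches enforce: 'encode_center_size' expects a 2-D TargetBox of shape [M, 4], while 'decode_center_size' expects a 3-D TargetBox of shape [N, M, 4] whose trailing dimensions match PriorBox.

import numpy as np

def check_box_coder_shapes(prior_box, target_box, code_type):
    # Mirrors the PADDLE_ENFORCE checks in BoxCoderOp::InferShape (sketch only).
    assert prior_box.ndim == 2 and prior_box.shape[1] == 4
    if code_type == 'encode_center_size':
        assert target_box.ndim == 2 and target_box.shape[1] == 4
    elif code_type == 'decode_center_size':
        assert target_box.ndim == 3
        assert target_box.shape[1] == prior_box.shape[0]
        assert target_box.shape[2] == prior_box.shape[1]

check_box_coder_shapes(np.zeros((10, 4)),
                       np.zeros((5, 10, 4)), 'decode_center_size')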
......@@ -66,7 +66,6 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data,
T* output) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < row * col) {
const int row_idx = idx / col;
const int col_idx = idx % col;
T prior_box_width =
prior_box_data[col_idx * len + 2] - prior_box_data[col_idx * len];
......@@ -79,17 +78,16 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data,
2;
T target_box_width = exp(prior_box_var_data[col_idx * len + 2] *
target_box_data[row_idx * len + 2]) *
target_box_data[idx * len + 2]) *
prior_box_width;
T target_box_height = exp(prior_box_var_data[col_idx * len + 3] *
target_box_data[row_idx * len + 3]) *
target_box_data[idx * len + 3]) *
prior_box_height;
T target_box_center_x = prior_box_var_data[col_idx * len] *
target_box_data[row_idx * len] *
prior_box_width +
target_box_data[idx * len] * prior_box_width +
prior_box_center_x;
T target_box_center_y = prior_box_var_data[col_idx * len + 1] *
target_box_data[row_idx * len + 1] *
target_box_data[idx * len + 1] *
prior_box_height +
prior_box_center_y;
......
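The kernel fix above switches the target-box reads from row_idx * len to idx * len, so each (row, col) element decodes its own [row, col] target box against prior box col. A hedged NumPy sketch of the same decode_center_size math, vectorized over the [N, M, 4] layout (illustration only, not code from this commit):

import numpy as np

def decode_center_size(prior_box, prior_box_var, target_box):
    # prior_box, prior_box_var: [M, 4]; target_box: [N, M, 4] -> output: [N, M, 4]
    pw = prior_box[:, 2] - prior_box[:, 0]
    ph = prior_box[:, 3] - prior_box[:, 1]
    pcx = (prior_box[:, 0] + prior_box[:, 2]) / 2
    pcy = (prior_box[:, 1] + prior_box[:, 3]) / 2
    # Each target box [n, m] is decoded against prior box m (broadcast over N).
    cx = prior_box_var[:, 0] * target_box[:, :, 0] * pw + pcx
    cy = prior_box_var[:, 1] * target_box[:, :, 1] * ph + pcy
    w = np.exp(prior_box_var[:, 2] * target_box[:, :, 2]) * pw
    h = np.exp(prior_box_var[:, 3] * target_box[:, :, 3]) * ph
    return np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=-1)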
......@@ -89,6 +89,7 @@ class BoxCoderKernel : public framework::OpKernel<T> {
for (int64_t i = 0; i < row; ++i) {
for (int64_t j = 0; j < col; ++j) {
size_t offset = i * col * len + j * len;
T prior_box_width =
prior_box_data[j * len + 2] - prior_box_data[j * len];
T prior_box_height =
......@@ -99,20 +100,19 @@ class BoxCoderKernel : public framework::OpKernel<T> {
(prior_box_data[j * len + 3] + prior_box_data[j * len + 1]) / 2;
T target_box_center_x = prior_box_var_data[j * len] *
target_box_data[i * len] * prior_box_width +
target_box_data[offset] * prior_box_width +
prior_box_center_x;
T target_box_center_y = prior_box_var_data[j * len + 1] *
target_box_data[i * len + 1] *
target_box_data[offset + 1] *
prior_box_height +
prior_box_center_y;
T target_box_width = std::exp(prior_box_var_data[j * len + 2] *
target_box_data[i * len + 2]) *
target_box_data[offset + 2]) *
prior_box_width;
T target_box_height = std::exp(prior_box_var_data[j * len + 3] *
target_box_data[i * len + 3]) *
target_box_data[offset + 3]) *
prior_box_height;
size_t offset = i * col * len + j * len;
output[offset] = target_box_center_x - target_box_width / 2;
output[offset + 1] = target_box_center_y - target_box_height / 2;
output[offset + 2] = target_box_center_x + target_box_width / 2;
......
......@@ -43,8 +43,8 @@ for _OP in set(__auto__):
globals()[_OP] = generate_layer_fn(_OP)
def detection_output(scores,
loc,
def detection_output(loc,
scores,
prior_box,
prior_box_var,
background_label=0,
......@@ -61,14 +61,14 @@ def detection_output(scores,
be zero if there is no valid bounding box.
Args:
scores(Variable): A 3-D Tensor with shape [N, C, M] represents the
predicted confidence predictions. N is the batch size, C is the
class number, M is number of bounding boxes. For each category
there are total M scores which corresponding M bounding boxes.
loc(Variable): A 3-D Tensor with shape [N, M, 4] that represents the
predicted locations of M bounding boxes. N is the batch size,
and each bounding box has four coordinate values with the layout
[xmin, ymin, xmax, ymax].
scores(Variable): A 3-D Tensor with shape [N, M, C] that represents the
predicted confidences. N is the batch size, C is the
class number, and M is the number of bounding boxes. For each category
there are M scores corresponding to the M bounding boxes.
prior_box(Variable): A 2-D Tensor with shape [M, 4] holds M boxes,
each box is represented as [xmin, ymin, xmax, ymax],
[xmin, ymin] is the left top coordinate of the anchor box,
......@@ -100,7 +100,7 @@ def detection_output(scores,
append_batch_size=False, dtype='float32')
pbv = layers.data(name='prior_box_var', shape=[10, 4],
append_batch_size=False, dtype='float32')
loc = layers.data(name='target_box', shape=[21, 4],
loc = layers.data(name='target_box', shape=[2, 21, 4],
append_batch_size=False, dtype='float32')
scores = layers.data(name='scores', shape=[2, 21, 10],
append_batch_size=False, dtype='float32')
......@@ -109,7 +109,6 @@ def detection_output(scores,
prior_box=pb,
prior_box_var=pbv)
"""
helper = LayerHelper("detection_output", **locals())
decoded_box = box_coder(
prior_box=prior_box,
......@@ -118,6 +117,7 @@ def detection_output(scores,
code_type='decode_center_size')
nmsed_outs = helper.create_tmp_variable(dtype=decoded_box.dtype)
scores = nn.transpose(scores, perm=[0, 2, 1])
helper.append_op(
type="multiclass_nms",
inputs={'Scores': scores,
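A usage sketch of the reordered detection_output(loc, scores, ...) signature, assembled from the updated docstring and the unit test further below (it assumes the fluid.layers API of this branch):

import paddle.fluid.layers as layers

# loc is [N, M, 4], scores is [N, M, C], matching the updated docstring.
pb = layers.data(name='prior_box', shape=[10, 4],
                 append_batch_size=False, dtype='float32')
pbv = layers.data(name='prior_box_var', shape=[10, 4],
                  append_batch_size=False, dtype='float32')
loc = layers.data(name='target_box', shape=[2, 10, 4],
                  append_batch_size=False, dtype='float32')
scores = layers.data(name='scores', shape=[2, 10, 20],
                     append_batch_size=False, dtype='float32')
out = layers.detection_output(loc=loc, scores=scores,
                              prior_box=pb, prior_box_var=pbv)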
......@@ -595,12 +595,13 @@ def multi_box_head(inputs,
name(str): Name of the prior box layer. Default: None.
Returns:
mbox_loc(list): The predicted boxes' location of the inputs.
The layout of each element is [N, H, W, Priors]. Priors
is the number of predicted boxof each position of each input.
mbox_conf(list): The predicted boxes' confidence of the inputs.
The layout of each element is [N, H, W, Priors]. Priors
is the number of predicted box of each position of each input.
mbox_loc(Variable): The predicted boxes' locations of the inputs.
The layout is [N, H*W*Priors, 4], where Priors
is the number of predicted boxes at each position of each input.
mbox_conf(Variable): The predicted boxes' confidences of the inputs.
The layout is [N, H*W*Priors, C], where Priors
is the number of predicted boxes at each position of each input
and C is the number of classes.
boxes(Variable): the output prior boxes of PriorBox.
The layout is [num_priors, 4]. num_priors is the total
box count of each position of inputs.
......@@ -751,7 +752,7 @@ def multi_box_head(inputs,
num_boxes = box.shape[2]
# get box_loc
num_loc_output = num_boxes * num_classes * 4
num_loc_output = num_boxes * 4
mbox_loc = nn.conv2d(
input=input,
num_filters=num_loc_output,
......@@ -760,7 +761,12 @@ def multi_box_head(inputs,
stride=stride)
mbox_loc = nn.transpose(mbox_loc, perm=[0, 2, 3, 1])
mbox_locs.append(mbox_loc)
new_shape = [
mbox_loc.shape[0],
mbox_loc.shape[1] * mbox_loc.shape[2] * mbox_loc.shape[3] / 4, 4
]
mbox_loc_flatten = ops.reshape(mbox_loc, shape=new_shape)
mbox_locs.append(mbox_loc_flatten)
# get conf_loc
num_conf_output = num_boxes * num_classes
......@@ -771,11 +777,18 @@ def multi_box_head(inputs,
padding=pad,
stride=stride)
conf_loc = nn.transpose(conf_loc, perm=[0, 2, 3, 1])
mbox_confs.append(conf_loc)
new_shape = [
conf_loc.shape[0], conf_loc.shape[1] * conf_loc.shape[2] *
conf_loc.shape[3] / num_classes, num_classes
]
conf_loc_flatten = ops.reshape(conf_loc, shape=new_shape)
mbox_confs.append(conf_loc_flatten)
if len(box_results) == 1:
box = box_results[0]
var = var_results[0]
mbox_locs_concat = mbox_locs[0]
mbox_confs_concat = mbox_confs[0]
else:
reshaped_boxes = []
reshaped_vars = []
......@@ -785,5 +798,7 @@ def multi_box_head(inputs,
box = tensor.concat(reshaped_boxes)
var = tensor.concat(reshaped_vars)
mbox_locs_concat = tensor.concat(mbox_locs, axis=1)
mbox_confs_concat = tensor.concat(mbox_confs, axis=1)
return mbox_locs, mbox_confs, box, var
return mbox_locs_concat, mbox_confs_concat, box, var
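The multi_box_head change above flattens each head with a transpose followed by a reshape instead of returning raw NHWC maps. A NumPy sketch with hypothetical sizes shows the intended layout for the location head; the confidence head follows the same pattern with num_classes in place of 4:

import numpy as np

# Hypothetical sizes for illustration: batch 2, 6 priors per position, 8x8 feature map.
N, priors, H, W = 2, 6, 8, 8
conv_out = np.random.rand(N, priors * 4, H, W).astype('float32')  # NCHW conv2d output

mbox_loc = conv_out.transpose(0, 2, 3, 1)                 # [N, H, W, Priors*4]
new_shape = [mbox_loc.shape[0],
             mbox_loc.shape[1] * mbox_loc.shape[2] * mbox_loc.shape[3] // 4, 4]
mbox_loc_flatten = mbox_loc.reshape(new_shape)            # [N, H*W*Priors, 4]
assert mbox_loc_flatten.shape == (N, H * W * priors, 4)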
......@@ -35,12 +35,12 @@ class TestDetection(unittest.TestCase):
dtype='float32')
loc = layers.data(
name='target_box',
shape=[20, 4],
shape=[2, 10, 4],
append_batch_size=False,
dtype='float32')
scores = layers.data(
name='scores',
shape=[2, 20, 10],
shape=[2, 10, 20],
append_batch_size=False,
dtype='float32')
out = layers.detection_output(
......@@ -117,9 +117,7 @@ class TestMultiBoxHead(unittest.TestCase):
assert len(box.shape) == 2
assert box.shape == var.shape
assert box.shape[1] == 4
for loc, conf in zip(mbox_locs, mbox_confs):
assert loc.shape[1:3] == conf.shape[1:3]
assert mbox_locs.shape[1] == mbox_confs.shape[1]
def multi_box_head_output(self, data_shape):
images = fluid.layers.data(
......
......@@ -51,8 +51,6 @@ def box_coder(target_box, prior_box, prior_box_var, output_box, code_type):
prior_box_var[:,:,3]
elif (code_type == "DecodeCenterSize"):
target_box = target_box.reshape(target_box.shape[0], 1,
target_box.shape[1])
target_box_x = prior_box_var[:,:,0] * target_box[:,:,0] * \
prior_box_width + prior_box_x
target_box_y = prior_box_var[:,:,1] * target_box[:,:,1] * \
......@@ -61,6 +59,7 @@ def box_coder(target_box, prior_box, prior_box_var, output_box, code_type):
prior_box_width
target_box_height = np.exp(prior_box_var[:,:,3] * target_box[:,:,3]) * \
prior_box_height
output_box[:, :, 0] = target_box_x - target_box_width / 2
output_box[:, :, 1] = target_box_y - target_box_height / 2
output_box[:, :, 2] = target_box_x + target_box_width / 2
......@@ -72,8 +71,14 @@ def batch_box_coder(prior_box, prior_box_var, target_box, lod, code_type):
m = prior_box.shape[0]
output_box = np.zeros((n, m, 4), dtype=np.float32)
for i in range(len(lod) - 1):
box_coder(target_box[lod[i]:lod[i + 1], :], prior_box, prior_box_var,
output_box[lod[i]:lod[i + 1], :, :], code_type)
if (code_type == "EncodeCenterSize"):
box_coder(target_box[lod[i]:lod[i + 1], :], prior_box,
prior_box_var, output_box[lod[i]:lod[i + 1], :, :],
code_type)
elif (code_type == "DecodeCenterSize"):
box_coder(target_box[lod[i]:lod[i + 1], :, :], prior_box,
prior_box_var, output_box[lod[i]:lod[i + 1], :, :],
code_type)
return output_box
......@@ -83,10 +88,10 @@ class TestBoxCoderOp(OpTest):
def setUp(self):
self.op_type = "box_coder"
lod = [[0, 20]]
lod = [[0, 1, 2, 3, 4, 5]]
prior_box = np.random.random((10, 4)).astype('float32')
prior_box_var = np.random.random((10, 4)).astype('float32')
target_box = np.random.random((20, 4)).astype('float32')
target_box = np.random.random((5, 10, 4)).astype('float32')
code_type = "DecodeCenterSize"
output_box = batch_box_coder(prior_box, prior_box_var, target_box,
lod[0], code_type)
......
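Putting the updated test pieces together, a small sketch of how the decode path of the NumPy reference is exercised (it assumes n inside batch_box_coder is the number of target-box rows, which is not shown in this hunk):

import numpy as np

lod = [[0, 1, 2, 3, 4, 5]]
prior_box = np.random.random((10, 4)).astype('float32')
prior_box_var = np.random.random((10, 4)).astype('float32')
target_box = np.random.random((5, 10, 4)).astype('float32')

output_box = batch_box_coder(prior_box, prior_box_var, target_box,
                             lod[0], "DecodeCenterSize")
# With 5 target-box rows and 10 prior boxes, the decoded output is [5, 10, 4].
assert output_box.shape == (5, 10, 4)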