提交 35dec3d7 编写于 作者: D dangqingqing

Fix bug in unit test.

上级 2731fd96
...@@ -41,13 +41,22 @@ class MulticlassNMSOp : public framework::OperatorWithKernel { ...@@ -41,13 +41,22 @@ class MulticlassNMSOp : public framework::OperatorWithKernel {
"The rank of Input(Bboxes) must be 3."); "The rank of Input(Bboxes) must be 3.");
PADDLE_ENFORCE_EQ(score_dims.size(), 3, PADDLE_ENFORCE_EQ(score_dims.size(), 3,
"The rank of Input(Scores) must be 3."); "The rank of Input(Scores) must be 3.");
PADDLE_ENFORCE_EQ(box_dims[2], 4); PADDLE_ENFORCE_EQ(box_dims[1], 4);
PADDLE_ENFORCE_EQ(box_dims[0], score_dims[2]); PADDLE_ENFORCE_EQ(box_dims[0], score_dims[2]);
// Here the box_dims[0] is not the real dimension of output. // Here the box_dims[0] is not the real dimension of output.
// It will be rewritten in the computing kernel. // It will be rewritten in the computing kernel.
ctx->SetOutputDim("Out", {box_dims[0], 6}); ctx->SetOutputDim("Out", {box_dims[0], 6});
} }
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::LoDTensor>("Scores")->type()),
ctx.device_context());
}
}; };
template <class T> template <class T>
...@@ -158,12 +167,12 @@ class MulticlassNMSKernel : public framework::OpKernel<T> { ...@@ -158,12 +167,12 @@ class MulticlassNMSKernel : public framework::OpKernel<T> {
const Tensor& scores, const Tensor& bboxes, const Tensor& scores, const Tensor& bboxes,
std::map<int, std::vector<int>>* indices, std::map<int, std::vector<int>>* indices,
int* num_nmsed_out) const { int* num_nmsed_out) const {
int64_t background_label = ctx.Attr<int64_t>("background_label"); int64_t background_label = ctx.Attr<int>("background_label");
int64_t nms_top_k = ctx.Attr<int64_t>("nms_top_k"); int64_t nms_top_k = ctx.Attr<int>("nms_top_k");
int64_t keep_top_k = ctx.Attr<int64_t>("keep_top_k"); int64_t keep_top_k = ctx.Attr<int>("keep_top_k");
T nms_threshold = static_cast<T>(ctx.Attr<float>("nms_threshold")); T nms_threshold = static_cast<T>(ctx.Attr<float>("nms_threshold"));
T nms_eta = static_cast<T>(ctx.Attr<float>("nms_eta")); T nms_eta = static_cast<T>(ctx.Attr<float>("nms_eta"));
T score_threshold = static_cast<T>(ctx.Attr<float>("confidence_threshold")); T score_threshold = static_cast<T>(ctx.Attr<float>("score_threshold"));
int64_t class_num = scores.dims()[0]; int64_t class_num = scores.dims()[0];
int64_t predict_dim = scores.dims()[1]; int64_t predict_dim = scores.dims()[1];
...@@ -173,7 +182,7 @@ class MulticlassNMSKernel : public framework::OpKernel<T> { ...@@ -173,7 +182,7 @@ class MulticlassNMSKernel : public framework::OpKernel<T> {
Tensor score = scores.Slice(c, c + 1); Tensor score = scores.Slice(c, c + 1);
NMSFast(bboxes, score, score_threshold, nms_threshold, nms_eta, nms_top_k, NMSFast(bboxes, score, score_threshold, nms_threshold, nms_eta, nms_top_k,
&((*indices)[c])); &((*indices)[c]));
num_det += indices[c].size(); num_det += (*indices)[c].size();
} }
*num_nmsed_out = num_det; *num_nmsed_out = num_det;
...@@ -230,8 +239,8 @@ class MulticlassNMSKernel : public framework::OpKernel<T> { ...@@ -230,8 +239,8 @@ class MulticlassNMSKernel : public framework::OpKernel<T> {
odata[count * kOutputDim + 3] = bdata[1]; // ymin odata[count * kOutputDim + 3] = bdata[1]; // ymin
odata[count * kOutputDim + 4] = bdata[2]; // xmax odata[count * kOutputDim + 4] = bdata[2]; // xmax
odata[count * kOutputDim + 5] = bdata[3]; // ymax odata[count * kOutputDim + 5] = bdata[3]; // ymax
count++;
} }
count++;
} }
} }
...@@ -240,10 +249,9 @@ class MulticlassNMSKernel : public framework::OpKernel<T> { ...@@ -240,10 +249,9 @@ class MulticlassNMSKernel : public framework::OpKernel<T> {
auto* scores = ctx.Input<Tensor>("Scores"); auto* scores = ctx.Input<Tensor>("Scores");
auto* outs = ctx.Output<LoDTensor>("Out"); auto* outs = ctx.Output<LoDTensor>("Out");
auto box_dims = boxes->dims();
auto score_dims = scores->dims(); auto score_dims = scores->dims();
int64_t batch_size = box_dims[0]; int64_t batch_size = score_dims[0];
int64_t class_num = score_dims[1]; int64_t class_num = score_dims[1];
int64_t predict_dim = score_dims[2]; int64_t predict_dim = score_dims[2];
...@@ -291,35 +299,37 @@ class MulticlassNMSOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -291,35 +299,37 @@ class MulticlassNMSOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor) A 2-D Tensor with shape [M, 4] represents the location " "(Tensor) A 2-D Tensor with shape [M, 4] represents the location "
"predictions with M bboxes. 4 is the number of " "predictions with M bboxes. 4 is the number of "
"each location coordinates."); "each location coordinates.");
AddOutput("Scores", AddInput("Scores",
"(Tensor) A 3-D Tensor with shape [N, C, M] represents the " "(Tensor) A 3-D Tensor with shape [N, C, M] represents the "
"confidence predictions. N is the batch size, C is the class " "confidence predictions. N is the batch size, C is the class "
"number, M is number of predictions for each class, which is " "number, M is number of predictions for each class, which is "
"the same with Bboxes."); "the same with Bboxes.");
AddAttr<int64_t>( AddAttr<int>(
"background_label", "background_label",
"(int64_t, defalut: 0) " "(int64_t, defalut: 0) "
"The index of background label, the background label will be ignored.") "The index of background label, the background label will be ignored.")
.SetDefault(0); .SetDefault(0);
AddAttr<float>("score_threshold",
"(float) "
"Only consider detections whose confidences are larger than "
"a threshold. If not provided, consider all boxes.");
AddAttr<int>("nms_top_k",
"(int64_t) "
"Maximum number of detections to be kept according to the "
"confidences aftern the filtering detections based on "
"score_threshold");
AddAttr<float>("nms_threshold", AddAttr<float>("nms_threshold",
"(float, defalut: 0.3) " "(float, defalut: 0.3) "
"The threshold to be used in nms.") "The threshold to be used in NMS.")
.SetDefault(0.3); .SetDefault(0.3);
AddAttr<int64_t>("nms_top_k",
"(int64_t) "
"Maximum number of results to be kept.");
AddAttr<float>("nms_eta", AddAttr<float>("nms_eta",
"(float) " "(float) "
"The parameter for adaptive nms.") "The parameter for adaptive NMS.")
.SetDefault(1.0); .SetDefault(1.0);
AddAttr<int64_t>("keep_top_k", AddAttr<int>("keep_top_k",
"(int64_t) " "(int64_t) "
"Number of total bboxes to be kept per image after nms " "Number of total bboxes to be kept per image after NMS "
"step. -1 means keeping all bboxes after nms step."); "step. -1 means keeping all bboxes after NMS step.");
AddAttr<float>("confidence_threshold",
"(float) "
"Only consider detections whose confidences are larger than "
"a threshold. If not provided, consider all boxes.");
AddOutput("Out", AddOutput("Out",
"(LoDTensor) A 2-D LoDTensor with shape [No, 6] represents the " "(LoDTensor) A 2-D LoDTensor with shape [No, 6] represents the "
"detections. Each row has 6 values: " "detections. Each row has 6 values: "
...@@ -329,15 +339,21 @@ class MulticlassNMSOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -329,15 +339,21 @@ class MulticlassNMSOpMaker : public framework::OpProtoAndCheckerMaker {
"offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is " "offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is "
"no detected bbox."); "no detected bbox.");
AddComment(R"DOC( AddComment(R"DOC(
This operators is to do multi-class non maximum suppression (NMS) on a batched This operator is to do multi-class non maximum suppression (NMS) on a batched
of boxes and scores. of boxes and scores.
This op greedily selects a subset of detection bounding boxes, pruning In the NMS step, this operator greedily selects a subset of detection bounding
away boxes that have high IOU (intersection over union) overlap (> thresh) boxes that have high scores larger than score_threshold, if providing this
with already selected boxes. It operates independently for each class for threshold, then selects the largest nms_top_k confidences scores if nms_top_k
which scores are provided, pruning boxes with score less than a provided is larger than -1. Then this operator pruns away boxes that have high IOU
threshold prior to applying NMS. (intersection over union) overlap with already selected boxes by adaptive
threshold NMS based on parameters of nms_threshold and nms_eta.
Aftern NMS step, only at most keep_top_k number of total bboxes are to be kept
per image if keep_top_k is larger than -1.
This operator support multi-class and batched inputs. It applying NMS
independently for each class.
)DOC"); )DOC");
} }
}; };
......
...@@ -69,7 +69,7 @@ def nms(boxes, scores, score_threshold, nms_threshold, top_k=200, eta=1.0): ...@@ -69,7 +69,7 @@ def nms(boxes, scores, score_threshold, nms_threshold, top_k=200, eta=1.0):
sorted_indices = np.argsort(-all_scores, axis=0) sorted_indices = np.argsort(-all_scores, axis=0)
sorted_scores = all_scores[sorted_indices] sorted_scores = all_scores[sorted_indices]
if top_k < -1 and top_k < sorted_indices.shape[0]: if top_k > -1 and top_k < sorted_indices.shape[0]:
sorted_indices = sorted_indices[:top_k] sorted_indices = sorted_indices[:top_k]
sorted_scores = sorted_scores[:top_k] sorted_scores = sorted_scores[:top_k]
...@@ -82,7 +82,7 @@ def nms(boxes, scores, score_threshold, nms_threshold, top_k=200, eta=1.0): ...@@ -82,7 +82,7 @@ def nms(boxes, scores, score_threshold, nms_threshold, top_k=200, eta=1.0):
if keep: if keep:
kept_idx = selected_indices[k] kept_idx = selected_indices[k]
overlap = iou(boxes[idx], boxes[kept_idx]) overlap = iou(boxes[idx], boxes[kept_idx])
keep = overlap <= adaptive_threshold keep = True if overlap <= adaptive_threshold else False
else: else:
break break
if keep: if keep:
...@@ -103,14 +103,14 @@ def multiclass_nms(boxes, scores, background, score_threshold, nms_threshold, ...@@ -103,14 +103,14 @@ def multiclass_nms(boxes, scores, background, score_threshold, nms_threshold,
if c == background: continue if c == background: continue
indices = nms(boxes, scores[c], score_threshold, nms_threshold, indices = nms(boxes, scores[c], score_threshold, nms_threshold,
nms_top_k) nms_top_k)
selected_indices.append((c, indices)) for idx in indices:
selected_indices.append((c, idx))
num_det += len(indices) num_det += len(indices)
if keep_top_k > -1 and num_det > keep_top_k: if keep_top_k > -1 and num_det > keep_top_k:
score_index = [] score_index = []
for c, indices in selected_indices: for c, idx in selected_indices:
for idx in indices: score_index.append((scores[c][idx], c, idx))
score_index.append((scores[c][idx], c, idx))
sorted_score_index = sorted( sorted_score_index = sorted(
score_index, key=lambda tup: tup[0], reverse=True) score_index, key=lambda tup: tup[0], reverse=True)
...@@ -134,19 +134,16 @@ def batched_multiclass_nms(boxes, scores, background, score_threshold, ...@@ -134,19 +134,16 @@ def batched_multiclass_nms(boxes, scores, background, score_threshold,
keep_top_k) keep_top_k)
lod.append(lod[-1] + len(nmsed_outs)) lod.append(lod[-1] + len(nmsed_outs))
if len(nmsed_outs) == 0: continue if len(nmsed_outs) == 0: continue
for c, indices in nmsed_outs: for c, idx in nmsed_outs:
for idx in indices: xmin, ymin, xmax, ymax = boxes[idx][:]
xmin, ymin, xmax, ymax = boxes[idx][:] det_outs.append([c, scores[n][c][idx], xmin, ymin, xmax, ymax])
det_outs.append(
(c, scores[n][c][idx], c, xmin, ymin, xmax, ymax))
return det_outs, lod return det_outs, lod
class TestMulticlassNMSOp(OpTest): class TestMulticlassNMSOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = 'multiclass_nms'
N = 7 N = 7
M = 1230 M = 1240
C = 21 C = 21
BOX_SIZE = 4 BOX_SIZE = 4
background = 0 background = 0
...@@ -155,7 +152,17 @@ class TestMulticlassNMSOp(OpTest): ...@@ -155,7 +152,17 @@ class TestMulticlassNMSOp(OpTest):
keep_top_k = 200 keep_top_k = 200
score_threshold = 0.01 score_threshold = 0.01
scores = np.random.random((N, C, M)).astype('float32') scores = np.random.random((N * M, C)).astype('float32')
def softmax(x):
shiftx = x - np.max(x).clip(-64.)
exps = np.exp(shiftx)
return exps / np.sum(exps)
scores = np.apply_along_axis(softmax, 1, scores)
scores = np.reshape(scores, (N, M, C))
scores = np.transpose(scores, (0, 2, 1))
boxes = np.random.random((M, BOX_SIZE)).astype('float32') boxes = np.random.random((M, BOX_SIZE)).astype('float32')
boxes[:, 0:2] = boxes[:, 0:2] * 0.5 boxes[:, 0:2] = boxes[:, 0:2] * 0.5
boxes[:, 2:4] = boxes[:, 0:2] * 0.5 + 0.5 boxes[:, 2:4] = boxes[:, 0:2] * 0.5 + 0.5
...@@ -163,8 +170,19 @@ class TestMulticlassNMSOp(OpTest): ...@@ -163,8 +170,19 @@ class TestMulticlassNMSOp(OpTest):
nmsed_outs, lod = batched_multiclass_nms(boxes, scores, background, nmsed_outs, lod = batched_multiclass_nms(boxes, scores, background,
score_threshold, nms_threshold, score_threshold, nms_threshold,
nms_top_k, keep_top_k) nms_top_k, keep_top_k)
nmsed_outs = np.array(nmsed_outs).astype('float32')
self.op_type = 'multiclass_nms'
self.inputs = {'Bboxes': boxes, 'Scores': scores} self.inputs = {'Bboxes': boxes, 'Scores': scores}
self.outputs = {'Out': (nmsed_outs, [lod])} self.outputs = {'Out': (nmsed_outs, [lod])}
self.attrs = {
'background_label': 0,
'nms_threshold': nms_threshold,
'nms_top_k': nms_top_k,
'keep_top_k': keep_top_k,
'score_threshold': score_threshold,
'nms_eta': 1.0,
}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -182,18 +200,3 @@ class TestIOU(unittest.TestCase): ...@@ -182,18 +200,3 @@ class TestIOU(unittest.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
# N = 7
# M = 8
# C = 5
# BOX_SIZE = 4
# background = 0
# nms_threshold = 0.3
# nms_top_k = 400
# keep_top_k = 200
# score_threshold = 0.5
# scores = np.random.random((N, C, M)).astype('float32')
# boxes = np.random.random((M, BOX_SIZE)).astype('float32')
# boxes[:, 0 : 2] = boxes[:, 0 : 2] * 0.5
# boxes[:, 2 : 4] = boxes[:, 0 : 2] * 0.5 + 0.5
# print nmsed_outs, lod
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册