未验证 提交 25dcd74d 编写于 作者: W wangguanzhong 提交者: GitHub

merge empty lod tensor, test=develop (#19228)

* merge_empty_lod_tensor, test=develop

* fix multiclass_nms, test=develop

* refine API.spec, test=develop

* add unittest case for fetch, test=develop

* add lod tensor test, test=develop

* return index for multiclass_nms, test=develop

* add api for multiclass_nms2

* update API.spc, test=develop

* refine api doc, test=develop

* fix test_detection.py, test=develop

* polish code, test=develop

* add more unittest case, test=develop
上级 c6756ed2
......@@ -396,7 +396,7 @@ paddle.fluid.layers.density_prior_box (ArgSpec(args=['input', 'image', 'densitie
paddle.fluid.layers.multi_box_head (ArgSpec(args=['inputs', 'image', 'base_size', 'num_classes', 'aspect_ratios', 'min_ratio', 'max_ratio', 'min_sizes', 'max_sizes', 'steps', 'step_w', 'step_h', 'offset', 'variance', 'flip', 'clip', 'kernel_size', 'pad', 'stride', 'name', 'min_max_aspect_ratios_order'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, 0.5, [0.1, 0.1, 0.2, 0.2], True, False, 1, 0, 1, None, False)), ('document', 'fd58078fdfffd899b91f992ba224628f'))
paddle.fluid.layers.bipartite_match (ArgSpec(args=['dist_matrix', 'match_type', 'dist_threshold', 'name'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '080ce0d54d3f1950ad5a3a8e5ae529e9'))
paddle.fluid.layers.target_assign (ArgSpec(args=['input', 'matched_indices', 'negative_indices', 'mismatch_value', 'name'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', 'e9685f32d21bec8c013626c0254502c5'))
paddle.fluid.layers.detection_output (ArgSpec(args=['loc', 'scores', 'prior_box', 'prior_box_var', 'background_label', 'nms_threshold', 'nms_top_k', 'keep_top_k', 'score_threshold', 'nms_eta'], varargs=None, keywords=None, defaults=(0, 0.3, 400, 200, 0.01, 1.0)), ('document', 'efae414c1137c7944d6174dd08c5347a'))
paddle.fluid.layers.detection_output (ArgSpec(args=['loc', 'scores', 'prior_box', 'prior_box_var', 'background_label', 'nms_threshold', 'nms_top_k', 'keep_top_k', 'score_threshold', 'nms_eta', 'return_index'], varargs=None, keywords=None, defaults=(0, 0.3, 400, 200, 0.01, 1.0, False)), ('document', '5485bcaceb0cde2695565a2ffd5bbd40'))
paddle.fluid.layers.ssd_loss (ArgSpec(args=['location', 'confidence', 'gt_box', 'gt_label', 'prior_box', 'prior_box_var', 'background_label', 'overlap_threshold', 'neg_pos_ratio', 'neg_overlap', 'loc_loss_weight', 'conf_loss_weight', 'match_type', 'mining_type', 'normalize', 'sample_size'], varargs=None, keywords=None, defaults=(None, 0, 0.5, 3.0, 0.5, 1.0, 1.0, 'per_prediction', 'max_negative', True, None)), ('document', '8edacd4b9bd02dd68931b9fa6bfe0cbd'))
paddle.fluid.layers.rpn_target_assign (ArgSpec(args=['bbox_pred', 'cls_logits', 'anchor_box', 'anchor_var', 'gt_boxes', 'is_crowd', 'im_info', 'rpn_batch_size_per_im', 'rpn_straddle_thresh', 'rpn_fg_fraction', 'rpn_positive_overlap', 'rpn_negative_overlap', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.0, 0.5, 0.7, 0.3, True)), ('document', '651d98d51879dfa1bc1cd40391786a41'))
paddle.fluid.layers.retinanet_target_assign (ArgSpec(args=['bbox_pred', 'cls_logits', 'anchor_box', 'anchor_var', 'gt_boxes', 'gt_labels', 'is_crowd', 'im_info', 'num_classes', 'positive_overlap', 'negative_overlap'], varargs=None, keywords=None, defaults=(1, 0.5, 0.4)), ('document', 'fa1d1c9d5e0111684c0db705f86a2595'))
......@@ -412,7 +412,8 @@ paddle.fluid.layers.polygon_box_transform (ArgSpec(args=['input', 'name'], varar
paddle.fluid.layers.yolov3_loss (ArgSpec(args=['x', 'gt_box', 'gt_label', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample_ratio', 'gt_score', 'use_label_smooth', 'name'], varargs=None, keywords=None, defaults=(None, True, None)), ('document', '400403175718d5a632402cdae88b01b8'))
paddle.fluid.layers.yolo_box (ArgSpec(args=['x', 'img_size', 'anchors', 'class_num', 'conf_thresh', 'downsample_ratio', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'ed56ff21536ca5c8ad418d0cfaf6a7b9'))
paddle.fluid.layers.box_clip (ArgSpec(args=['input', 'im_info', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '9ddee76cb808db83768bf68010e39b2b'))
paddle.fluid.layers.multiclass_nms (ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)), ('document', '51a388c4d067ea93a6a60492db40c7af'))
paddle.fluid.layers.multiclass_nms (ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)), ('document', 'f6e333d76922c6e564413b4d216c245c'))
paddle.fluid.layers.multiclass_nms2 (ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'return_index', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, False, None)), ('document', 'be156186ee7a2ee56ab30b964acb15e5'))
paddle.fluid.layers.retinanet_detection_output (ArgSpec(args=['bboxes', 'scores', 'anchors', 'im_info', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'nms_eta'], varargs=None, keywords=None, defaults=(0.05, 1000, 100, 0.3, 1.0)), ('document', '078d28607ce261a0cba2b965a79f6bb8'))
paddle.fluid.layers.distribute_fpn_proposals (ArgSpec(args=['fpn_rois', 'min_level', 'max_level', 'refer_level', 'refer_scale', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '6c023b9401214ae387a8b2d92638e5e4'))
paddle.fluid.layers.box_decoder_and_assign (ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'box_score', 'box_clip', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '3619a7847709f5868f5e929065947b38'))
......
......@@ -61,12 +61,17 @@ void FetchOpHandle::RunImpl() {
var_handle->name());
auto &t = var->Get<framework::LoDTensor>();
if (platform::is_gpu_place(t.place())) {
if (t.IsInitialized() && t.numel() > 0) {
if (platform::is_gpu_place(t.place())) {
#ifdef PADDLE_WITH_CUDA
TensorCopy(t, cpu, &tensors_[i]);
TensorCopy(t, cpu, &tensors_[i]);
#endif
} else {
tensors_[i].ShareDataWith(t);
}
} else {
tensors_[i].ShareDataWith(t);
tensors_[i].clear();
tensors_[i].Resize({0});
}
tensors_[i].set_lod(t.lod());
}
......
......@@ -326,17 +326,28 @@ void LoDTensor::MergeLoDTensor(
PADDLE_ENFORCE(!lod_tensors.empty());
framework::DDim new_dim = lod_tensors[0]->dims();
auto new_type = lod_tensors[0]->type();
proto::VarType::Type new_type = proto::VarType::FP32;
framework::DataLayout new_layout = lod_tensors[0]->layout();
for (auto *t : lod_tensors) {
if (t->numel() && t->IsInitialized()) {
new_dim = t->dims();
new_type = t->type();
new_layout = t->layout();
break;
}
}
LoD new_lod = lod_tensors[0]->lod();
for (size_t i = 1; i < lod_tensors.size(); ++i) {
auto *t = lod_tensors[i];
PADDLE_ENFORCE_EQ(new_type, t->type());
PADDLE_ENFORCE_EQ(new_layout, t->layout());
PADDLE_ENFORCE_EQ(framework::product(new_dim) / new_dim[0],
framework::product(t->dims()) / t->dims()[0]);
new_dim[0] += t->dims()[0];
if (t->numel() && t->IsInitialized()) {
PADDLE_ENFORCE_EQ(new_type, t->type());
PADDLE_ENFORCE_EQ(new_layout, t->layout());
PADDLE_ENFORCE_EQ(framework::product(new_dim) / new_dim[0],
framework::product(t->dims()) / t->dims()[0]);
new_dim[0] += t->dims()[0];
}
auto &lod = t->lod();
PADDLE_ENFORCE_EQ(new_lod.size(), lod.size());
......@@ -356,6 +367,9 @@ void LoDTensor::MergeLoDTensor(
int begin = 0;
for (auto *src : lod_tensors) {
int end = begin + src->dims()[0];
if (end == begin) {
continue;
}
auto dst = Slice(begin, end);
framework::TensorCopy(*src, dst_place, &dst);
begin = end;
......
......@@ -185,7 +185,15 @@ TEST(LoD, MergeLoDTensor) {
dst_ptr[i] = i;
}
std::vector<const LoDTensor*> lods{&lod_tensor0, &lod_tensor1};
LoDTensor lod_tensor2;
LoD lod2;
lod2.push_back(std::vector<size_t>({0}));
lod2.push_back(std::vector<size_t>({0}));
lod_tensor2.set_lod(lod2);
lod_tensor2.Resize({0});
dst_ptr = lod_tensor2.mutable_data<float>(place);
std::vector<const LoDTensor*> lods{&lod_tensor0, &lod_tensor1, &lod_tensor2};
LoDTensor lod_tensor;
lod_tensor.MergeLoDTensor(lods, place);
......
......@@ -328,7 +328,8 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
void MultiClassOutput(const platform::DeviceContext& ctx,
const Tensor& scores, const Tensor& bboxes,
const std::map<int, std::vector<int>>& selected_indices,
const int scores_size, Tensor* outs) const {
const int scores_size, Tensor* outs,
int* oindices = nullptr, const int offset = 0) const {
int64_t class_num = scores.dims()[1];
int64_t predict_dim = scores.dims()[1];
int64_t box_size = bboxes.dims()[1];
......@@ -358,9 +359,15 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
if (scores_size == 3) {
bdata = bboxes_data + idx * box_size;
odata[count * out_dim + 1] = sdata[idx]; // score
if (oindices != nullptr) {
oindices[count] = offset + idx;
}
} else {
bdata = bbox.data<T>() + idx * box_size;
odata[count * out_dim + 1] = *(scores_data + idx * class_num + label);
if (oindices != nullptr) {
oindices[count] = offset + idx * class_num + label;
}
}
// xmin, ymin, xmax, ymax or multi-points coordinates
std::memcpy(odata + count * out_dim + 2, bdata, box_size * sizeof(T));
......@@ -373,7 +380,8 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
auto* boxes = ctx.Input<LoDTensor>("BBoxes");
auto* scores = ctx.Input<LoDTensor>("Scores");
auto* outs = ctx.Output<LoDTensor>("Out");
bool return_index = ctx.HasOutput("Index") ? true : false;
auto index = ctx.Output<LoDTensor>("Index");
auto score_dims = scores->dims();
auto score_size = score_dims.size();
auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
......@@ -406,35 +414,55 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
int num_kept = batch_starts.back();
if (num_kept == 0) {
T* od = outs->mutable_data<T>({1, 1}, ctx.GetPlace());
od[0] = -1;
batch_starts = {0, 1};
if (return_index) {
outs->mutable_data<T>({0, out_dim}, ctx.GetPlace());
index->mutable_data<int>({0, 1}, ctx.GetPlace());
} else {
T* od = outs->mutable_data<T>({1, 1}, ctx.GetPlace());
od[0] = -1;
batch_starts = {0, 1};
}
} else {
outs->mutable_data<T>({num_kept, out_dim}, ctx.GetPlace());
int offset = 0;
int* oindices = nullptr;
for (int i = 0; i < n; ++i) {
if (score_size == 3) {
scores_slice = scores->Slice(i, i + 1);
boxes_slice = boxes->Slice(i, i + 1);
scores_slice.Resize({score_dims[1], score_dims[2]});
boxes_slice.Resize({score_dims[2], box_dim});
if (return_index) {
offset = i * score_dims[2];
}
} else {
auto boxes_lod = boxes->lod().back();
scores_slice = scores->Slice(boxes_lod[i], boxes_lod[i + 1]);
boxes_slice = boxes->Slice(boxes_lod[i], boxes_lod[i + 1]);
if (return_index) {
offset = boxes_lod[i] * score_dims[1];
}
}
int64_t s = batch_starts[i];
int64_t e = batch_starts[i + 1];
if (e > s) {
Tensor out = outs->Slice(s, e);
if (return_index) {
int* output_idx =
index->mutable_data<int>({num_kept, 1}, ctx.GetPlace());
oindices = output_idx + s;
}
MultiClassOutput(dev_ctx, scores_slice, boxes_slice, all_indices[i],
score_dims.size(), &out);
score_dims.size(), &out, oindices, offset);
}
}
}
framework::LoD lod;
lod.emplace_back(batch_starts);
if (return_index) {
index->set_lod(lod);
}
outs->set_lod(lod);
}
};
......@@ -519,13 +547,45 @@ This operator support multi-class and batched inputs. It applying NMS
independently for each class. The outputs is a 2-D LoDTenosr, for each
image, the offsets in first dimension of LoDTensor are called LoD, the number
of offset is N + 1, where N is the batch size. If LoD[i + 1] - LoD[i] == 0,
means there is no detected bbox for this image. If there is no detected boxes
for all images, all the elements in LoD are set to {1}, and the Out only
contains one value which is -1.
means there is no detected bbox for this image.
)DOC");
}
};
class MultiClassNMS2Op : public MultiClassNMSOp {
public:
MultiClassNMS2Op(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: MultiClassNMSOp(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext* ctx) const override {
MultiClassNMSOp::InferShape(ctx);
auto box_dims = ctx->GetInputDim("BBoxes");
auto score_dims = ctx->GetInputDim("Scores");
auto score_size = score_dims.size();
if (score_size == 3) {
ctx->SetOutputDim("Index", {box_dims[1], 1});
} else {
ctx->SetOutputDim("Index", {-1, 1});
}
}
};
class MultiClassNMS2OpMaker : public MultiClassNMSOpMaker {
public:
void Make() override {
MultiClassNMSOpMaker::Make();
AddOutput("Index",
"(LoDTensor) A 2-D LoDTensor with shape [No, 1] represents the "
"index of selected bbox. The index is the absolute index cross "
"batches.")
.AsIntermediate();
}
};
} // namespace operators
} // namespace paddle
......@@ -535,3 +595,8 @@ REGISTER_OPERATOR(multiclass_nms, ops::MultiClassNMSOp,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(multiclass_nms, ops::MultiClassNMSKernel<float>,
ops::MultiClassNMSKernel<double>);
REGISTER_OPERATOR(multiclass_nms2, ops::MultiClassNMS2Op,
ops::MultiClassNMS2OpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(multiclass_nms2, ops::MultiClassNMSKernel<float>,
ops::MultiClassNMSKernel<double>);
......@@ -53,6 +53,7 @@ __all__ = [
'yolo_box',
'box_clip',
'multiclass_nms',
'multiclass_nms2',
'retinanet_detection_output',
'distribute_fpn_proposals',
'box_decoder_and_assign',
......@@ -446,7 +447,8 @@ def detection_output(loc,
nms_top_k=400,
keep_top_k=200,
score_threshold=0.01,
nms_eta=1.0):
nms_eta=1.0,
return_index=False):
"""
**Detection Output Layer for Single Shot Multibox Detector (SSD).**
......@@ -489,21 +491,32 @@ def detection_output(loc,
score_threshold(float): Threshold to filter out bounding boxes with
low confidence score. If not provided, consider all boxes.
nms_eta(float): The parameter for adaptive NMS.
return_index(bool): Whether return selected index. Default: False
Returns:
Variable:
The detection outputs is a LoDTensor with shape [No, 6].
Each row has six values: [label, confidence, xmin, ymin, xmax, ymax].
`No` is the total number of detections in this mini-batch. For each
instance, the offsets in first dimension are called LoD, the offset
number is N + 1, N is the batch size. The i-th image has
`LoD[i + 1] - LoD[i]` detected results, if it is 0, the i-th image
has no detected results. If all images have not detected results,
LoD will be set to {1}, and output tensor only contains one
value, which is -1.
(After version 1.3, when no boxes detected, the lod is changed
from {0} to {1}.)
A tuple with two Variables: (Out, Index) if return_index is True,
otherwise, a tuple with one Variable(Out) is returned.
Out: The detection outputs is a LoDTensor with shape [No, 6]. Each row
has six values: [label, confidence, xmin, ymin, xmax, ymax]. `No` is
the total number of detections in this mini-batch. For each instance,
the offsets in first dimension are called LoD, the offset number is
N + 1, N is the batch size. The i-th image has `LoD[i + 1] - LoD[i]`
detected results, if it is 0, the i-th image has no detected results.
If all images have not detected results, LoD will be set to {1}, and
output tensor only contains one value, which is -1.
(After version 1.3, when no boxes detected, the lod is changed
from {0} to {1}.)
Index: Only return when return_index is True. A 2-D LoDTensor with
shape [No, 1] represents the selected index which type is Integer.
The index is the absolute value cross batches. No is the same number
as Out. If the index is used to gather other attribute such as age,
one needs to reshape the input(N, M, 1) to (N * M, 1) as first, where
N is the batch size and M is the number of boxes.
Examples:
.. code-block:: python
......@@ -518,10 +531,11 @@ def detection_output(loc,
append_batch_size=False, dtype='float32')
scores = fluid.layers.data(name='scores', shape=[2, 21, 10],
append_batch_size=False, dtype='float32')
nmsed_outs = fluid.layers.detection_output(scores=scores,
nmsed_outs, index = fluid.layers.detection_output(scores=scores,
loc=loc,
prior_box=pb,
prior_box_var=pbv)
prior_box_var=pbv,
return_index=True)
"""
helper = LayerHelper("detection_output", **locals())
decoded_box = box_coder(
......@@ -534,20 +548,40 @@ def detection_output(loc,
scores.stop_gradient = True
nmsed_outs = helper.create_variable_for_type_inference(
dtype=decoded_box.dtype)
helper.append_op(
type="multiclass_nms",
inputs={'Scores': scores,
'BBoxes': decoded_box},
outputs={'Out': nmsed_outs},
attrs={
'background_label': 0,
'nms_threshold': nms_threshold,
'nms_top_k': nms_top_k,
'keep_top_k': keep_top_k,
'score_threshold': score_threshold,
'nms_eta': 1.0
})
if return_index:
index = helper.create_variable_for_type_inference(dtype='int')
helper.append_op(
type="multiclass_nms2",
inputs={'Scores': scores,
'BBoxes': decoded_box},
outputs={'Out': nmsed_outs,
'Index': index},
attrs={
'background_label': 0,
'nms_threshold': nms_threshold,
'nms_top_k': nms_top_k,
'keep_top_k': keep_top_k,
'score_threshold': score_threshold,
'nms_eta': 1.0,
})
index.stop_gradient = True
else:
helper.append_op(
type="multiclass_nms",
inputs={'Scores': scores,
'BBoxes': decoded_box},
outputs={'Out': nmsed_outs},
attrs={
'background_label': 0,
'nms_threshold': nms_threshold,
'nms_top_k': nms_top_k,
'keep_top_k': keep_top_k,
'score_threshold': score_threshold,
'nms_eta': 1.0,
})
nmsed_outs.stop_gradient = True
if return_index:
return nmsed_outs, index
return nmsed_outs
......@@ -2690,7 +2724,6 @@ def multiclass_nms(bboxes,
is larger than -1. Then this operator pruns away boxes that have high IOU
(intersection over union) overlap with already selected boxes by adaptive
threshold NMS based on parameters of nms_threshold and nms_eta.
Aftern NMS step, at most keep_top_k number of total bboxes are to be kept
per image if keep_top_k is larger than -1.
......@@ -2708,7 +2741,7 @@ def multiclass_nms(bboxes,
nms_threshold = 0.3
background_label = 0
score_threshold = 0
Then:
iou = 4/11 > 0.3
......@@ -2809,6 +2842,141 @@ def multiclass_nms(bboxes,
return output
def multiclass_nms2(bboxes,
scores,
score_threshold,
nms_top_k,
keep_top_k,
nms_threshold=0.3,
normalized=True,
nms_eta=1.,
background_label=0,
return_index=False,
name=None):
"""
**Multiclass NMS2**
This operator is to do multi-class non maximum suppression (NMS) on
boxes and scores.
In the NMS step, this operator greedily selects a subset of detection bounding
boxes that have high scores larger than score_threshold, if providing this
threshold, then selects the largest nms_top_k confidences scores if nms_top_k
is larger than -1. Then this operator pruns away boxes that have high IOU
(intersection over union) overlap with already selected boxes by adaptive
threshold NMS based on parameters of nms_threshold and nms_eta.
Aftern NMS step, at most keep_top_k number of total bboxes are to be kept
per image if keep_top_k is larger than -1.
Args:
bboxes (Variable): Two types of bboxes are supported:
1. (Tensor) A 3-D Tensor with shape
[N, M, 4 or 8 16 24 32] represents the
predicted locations of M bounding bboxes,
N is the batch size. Each bounding box has four
coordinate values and the layout is
[xmin, ymin, xmax, ymax], when box size equals to 4.
2. (LoDTensor) A 3-D Tensor with shape [M, C, 4]
M is the number of bounding boxes, C is the
class number
scores (Variable): Two types of scores are supported:
1. (Tensor) A 3-D Tensor with shape [N, C, M]
represents the predicted confidence predictions.
N is the batch size, C is the class number, M is
number of bounding boxes. For each category there
are total M scores which corresponding M bounding
boxes. Please note, M is equal to the 2nd dimension
of BBoxes.
2. (LoDTensor) A 2-D LoDTensor with shape [M, C].
M is the number of bbox, C is the class number.
In this case, input BBoxes should be the second
case with shape [M, C, 4].
background_label (int): The index of background label, the background
label will be ignored. If set to -1, then all
categories will be considered. Default: 0
score_threshold (float): Threshold to filter out bounding boxes with
low confidence score. If not provided,
consider all boxes.
nms_top_k (int): Maximum number of detections to be kept according to
the confidences aftern the filtering detections based
on score_threshold.
nms_threshold (float): The threshold to be used in NMS. Default: 0.3
nms_eta (float): The threshold to be used in NMS. Default: 1.0
keep_top_k (int): Number of total bboxes to be kept per image after NMS
step. -1 means keeping all bboxes after NMS step.
normalized (bool): Whether detections are normalized. Default: True
return_index(bool): Whether return selected index. Default: False
name(str): Name of the multiclass nms op. Default: None.
Returns:
A tuple with two Variables: (Out, Index) if return_index is True,
otherwise, a tuple with one Variable(Out) is returned.
Out: A 2-D LoDTensor with shape [No, 6] represents the detections.
Each row has 6 values: [label, confidence, xmin, ymin, xmax, ymax]
or A 2-D LoDTensor with shape [No, 10] represents the detections.
Each row has 10 values: [label, confidence, x1, y1, x2, y2, x3, y3,
x4, y4]. No is the total number of detections.
If all images have not detected results, all elements in LoD will be
0, and output tensor is empty (None).
Index: Only return when return_index is True. A 2-D LoDTensor with
shape [No, 1] represents the selected index which type is Integer.
The index is the absolute value cross batches. No is the same number
as Out. If the index is used to gather other attribute such as age,
one needs to reshape the input(N, M, 1) to (N * M, 1) as first, where
N is the batch size and M is the number of boxes.
Examples:
.. code-block:: python
import paddle.fluid as fluid
boxes = fluid.layers.data(name='bboxes', shape=[81, 4],
dtype='float32', lod_level=1)
scores = fluid.layers.data(name='scores', shape=[81],
dtype='float32', lod_level=1)
out, index = fluid.layers.multiclass_nms2(bboxes=boxes,
scores=scores,
background_label=0,
score_threshold=0.5,
nms_top_k=400,
nms_threshold=0.3,
keep_top_k=200,
normalized=False,
return_index=True)
"""
helper = LayerHelper('multiclass_nms2', **locals())
output = helper.create_variable_for_type_inference(dtype=bboxes.dtype)
index = helper.create_variable_for_type_inference(dtype='int')
helper.append_op(
type="multiclass_nms2",
inputs={'BBoxes': bboxes,
'Scores': scores},
attrs={
'background_label': background_label,
'score_threshold': score_threshold,
'nms_top_k': nms_top_k,
'nms_threshold': nms_threshold,
'nms_eta': nms_eta,
'keep_top_k': keep_top_k,
'nms_eta': nms_eta,
'normalized': normalized
},
outputs={'Out': output,
'Index': index})
output.stop_gradient = True
index.stop_gradient = True
if return_index:
return output, index
return output
def distribute_fpn_proposals(fpn_rois,
min_level,
max_level,
......
......@@ -47,7 +47,15 @@ class TestDetection(unittest.TestCase):
dtype='float32')
out = layers.detection_output(
scores=scores, loc=loc, prior_box=pb, prior_box_var=pbv)
out2, index = layers.detection_output(
scores=scores,
loc=loc,
prior_box=pb,
prior_box_var=pbv,
return_index=True)
self.assertIsNotNone(out)
self.assertIsNotNone(out2)
self.assertIsNotNone(index)
self.assertEqual(out.shape[-1], 6)
print(str(program))
......@@ -523,6 +531,21 @@ class TestMulticlassNMS(unittest.TestCase):
self.assertIsNotNone(output)
class TestMulticlassNMS2(unittest.TestCase):
def test_multiclass_nms2(self):
program = Program()
with program_guard(program):
bboxes = layers.data(
name='bboxes', shape=[-1, 10, 4], dtype='float32')
scores = layers.data(name='scores', shape=[-1, 10], dtype='float32')
output = layers.multiclass_nms2(bboxes, scores, 0.3, 400, 200, 0.7)
output2, index = layers.multiclass_nms2(
bboxes, scores, 0.3, 400, 200, 0.7, return_index=True)
self.assertIsNotNone(output)
self.assertIsNotNone(output2)
self.assertIsNotNone(index)
class TestCollectFpnPropsals(unittest.TestCase):
def test_collect_fpn_proposals(self):
program = Program()
......
......@@ -22,17 +22,25 @@ import unittest
class TestFetchVar(op_test.OpTest):
def set_input(self):
self.val = numpy.array([1, 3, 5]).astype(numpy.int32)
def test_fetch_var(self):
val = numpy.array([1, 3, 5]).astype(numpy.int32)
self.set_input()
x = layers.create_tensor(dtype="int32", persistable=True, name="x")
layers.assign(input=val, output=x)
layers.assign(input=self.val, output=x)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_main_program(), feed={}, fetch_list=[])
fetched_x = fluid.executor._fetch_var("x")
self.assertTrue(
numpy.array_equal(fetched_x, val),
"fetch_x=%s val=%s" % (fetched_x, val))
self.assertEqual(fetched_x.dtype, val.dtype)
numpy.array_equal(fetched_x, self.val),
"fetch_x=%s val=%s" % (fetched_x, self.val))
self.assertEqual(fetched_x.dtype, self.val.dtype)
class TestFetchNullVar(TestFetchVar):
def set_input(self):
self.val = numpy.array([]).astype(numpy.int32)
if __name__ == '__main__':
......
......@@ -156,12 +156,14 @@ def multiclass_nms(boxes, scores, background, score_threshold, nms_threshold,
def lod_multiclass_nms(boxes, scores, background, score_threshold,
nms_threshold, nms_top_k, keep_top_k, box_lod,
normalized):
num_class = boxes.shape[1]
det_outs = []
lod = []
head = 0
for n in range(len(box_lod[0])):
box = boxes[head:head + box_lod[0][n]]
score = scores[head:head + box_lod[0][n]]
offset = head
head = head + box_lod[0][n]
nmsed_outs, nmsed_num = multiclass_nms(
box,
......@@ -173,19 +175,21 @@ def lod_multiclass_nms(boxes, scores, background, score_threshold,
keep_top_k,
normalized,
shared=False)
lod.append(nmsed_num)
if nmsed_num == 0:
continue
lod.append(nmsed_num)
tmp_det_out = []
for c, indices in nmsed_outs.items():
for idx in indices:
xmin, ymin, xmax, ymax = box[idx, c, :]
tmp_det_out.append([c, score[idx][c], xmin, ymin, xmax, ymax])
tmp_det_out.append([
c, score[idx][c], xmin, ymin, xmax, ymax,
offset * num_class + idx * num_class + c
])
sorted_det_out = sorted(
tmp_det_out, key=lambda tup: tup[0], reverse=False)
det_outs.extend(sorted_det_out)
if len(lod) == 0:
lod.append(1)
return det_outs, lod
......@@ -199,8 +203,9 @@ def batched_multiclass_nms(boxes,
keep_top_k,
normalized=True):
batch_size = scores.shape[0]
num_boxes = scores.shape[2]
det_outs = []
index_outs = []
lod = []
for n in range(batch_size):
nmsed_outs, nmsed_num = multiclass_nms(
......@@ -213,21 +218,21 @@ def batched_multiclass_nms(boxes,
keep_top_k,
normalized,
shared=True)
lod.append(nmsed_num)
if nmsed_num == 0:
continue
lod.append(nmsed_num)
tmp_det_out = []
for c, indices in nmsed_outs.items():
for idx in indices:
xmin, ymin, xmax, ymax = boxes[n][idx][:]
tmp_det_out.append(
[c, scores[n][c][idx], xmin, ymin, xmax, ymax])
tmp_det_out.append([
c, scores[n][c][idx], xmin, ymin, xmax, ymax,
idx + n * num_boxes
])
sorted_det_out = sorted(
tmp_det_out, key=lambda tup: tup[0], reverse=False)
det_outs.extend(sorted_det_out)
if len(lod) == 0:
lod += [1]
return det_outs, lod
......@@ -262,11 +267,13 @@ class TestMulticlassNMSOp(OpTest):
boxes[:, :, 0:2] = boxes[:, :, 0:2] * 0.5
boxes[:, :, 2:4] = boxes[:, :, 2:4] * 0.5 + 0.5
nmsed_outs, lod = batched_multiclass_nms(boxes, scores, background,
score_threshold, nms_threshold,
nms_top_k, keep_top_k)
nmsed_outs = [-1] if not nmsed_outs else nmsed_outs
nmsed_outs = np.array(nmsed_outs).astype('float32')
det_outs, lod = batched_multiclass_nms(boxes, scores, background,
score_threshold, nms_threshold,
nms_top_k, keep_top_k)
lod = [1] if not det_outs else lod
det_outs = [[-1, 0]] if not det_outs else det_outs
det_outs = np.array(det_outs)
nmsed_outs = det_outs[:, :-1].astype('float32')
self.op_type = 'multiclass_nms'
self.inputs = {'BBoxes': boxes, 'Scores': scores}
......@@ -324,11 +331,12 @@ class TestMulticlassNMSLoDInput(OpTest):
boxes[:, :, 2] = boxes[:, :, 2] * 10 + 10
boxes[:, :, 3] = boxes[:, :, 3] * 10 + 10
nmsed_outs, lod = lod_multiclass_nms(
det_outs, lod = lod_multiclass_nms(
boxes, scores, background, score_threshold, nms_threshold,
nms_top_k, keep_top_k, box_lod, normalized)
nmsed_outs = [-1] if not nmsed_outs else nmsed_outs
nmsed_outs = np.array(nmsed_outs).astype('float32')
det_outs = np.array(det_outs).astype('float32')
nmsed_outs = det_outs[:, :-1].astype('float32') if len(
det_outs) else det_outs
self.op_type = 'multiclass_nms'
self.inputs = {
'BBoxes': (boxes, box_lod),
......@@ -359,5 +367,137 @@ class TestIOU(unittest.TestCase):
self.assertTrue(np.allclose(calc_output, expt_output))
class TestMulticlassNMS2Op(TestMulticlassNMSOp):
def setUp(self):
self.set_argument()
N = 7
M = 1200
C = 21
BOX_SIZE = 4
background = 0
nms_threshold = 0.3
nms_top_k = 400
keep_top_k = 200
score_threshold = self.score_threshold
scores = np.random.random((N * M, C)).astype('float32')
def softmax(x):
shiftx = x - np.max(x).clip(-64.)
exps = np.exp(shiftx)
return exps / np.sum(exps)
scores = np.apply_along_axis(softmax, 1, scores)
scores = np.reshape(scores, (N, M, C))
scores = np.transpose(scores, (0, 2, 1))
boxes = np.random.random((N, M, BOX_SIZE)).astype('float32')
boxes[:, :, 0:2] = boxes[:, :, 0:2] * 0.5
boxes[:, :, 2:4] = boxes[:, :, 2:4] * 0.5 + 0.5
det_outs, lod = batched_multiclass_nms(boxes, scores, background,
score_threshold, nms_threshold,
nms_top_k, keep_top_k)
det_outs = np.array(det_outs)
nmsed_outs = det_outs[:, :-1].astype('float32') if len(
det_outs) else det_outs
index_outs = det_outs[:, -1:].astype('int') if len(
det_outs) else det_outs
self.op_type = 'multiclass_nms2'
self.inputs = {'BBoxes': boxes, 'Scores': scores}
self.outputs = {
'Out': (nmsed_outs, [lod]),
'Index': (index_outs, [lod])
}
self.attrs = {
'background_label': 0,
'nms_threshold': nms_threshold,
'nms_top_k': nms_top_k,
'keep_top_k': keep_top_k,
'score_threshold': score_threshold,
'nms_eta': 1.0,
'normalized': True,
}
def test_check_output(self):
self.check_output()
class TestMulticlassNMS2OpNoOutput(TestMulticlassNMS2Op):
def set_argument(self):
# Here set 2.0 to test the case there is no outputs.
# In practical use, 0.0 < score_threshold < 1.0
self.score_threshold = 2.0
class TestMulticlassNMS2LoDInput(TestMulticlassNMSLoDInput):
def setUp(self):
self.set_argument()
M = 1200
C = 21
BOX_SIZE = 4
box_lod = [[1200]]
background = 0
nms_threshold = 0.3
nms_top_k = 400
keep_top_k = 200
score_threshold = self.score_threshold
normalized = False
scores = np.random.random((M, C)).astype('float32')
def softmax(x):
shiftx = x - np.max(x).clip(-64.)
exps = np.exp(shiftx)
return exps / np.sum(exps)
scores = np.apply_along_axis(softmax, 1, scores)
boxes = np.random.random((M, C, BOX_SIZE)).astype('float32')
boxes[:, :, 0] = boxes[:, :, 0] * 10
boxes[:, :, 1] = boxes[:, :, 1] * 10
boxes[:, :, 2] = boxes[:, :, 2] * 10 + 10
boxes[:, :, 3] = boxes[:, :, 3] * 10 + 10
det_outs, lod = lod_multiclass_nms(
boxes, scores, background, score_threshold, nms_threshold,
nms_top_k, keep_top_k, box_lod, normalized)
det_outs = np.array(det_outs)
nmsed_outs = det_outs[:, :-1].astype('float32') if len(
det_outs) else det_outs
index_outs = det_outs[:, -1:].astype('int') if len(
det_outs) else det_outs
self.op_type = 'multiclass_nms2'
self.inputs = {
'BBoxes': (boxes, box_lod),
'Scores': (scores, box_lod),
}
self.outputs = {
'Out': (nmsed_outs, [lod]),
'Index': (index_outs, [lod])
}
self.attrs = {
'background_label': 0,
'nms_threshold': nms_threshold,
'nms_top_k': nms_top_k,
'keep_top_k': keep_top_k,
'score_threshold': score_threshold,
'nms_eta': 1.0,
'normalized': normalized,
}
def test_check_output(self):
self.check_output()
class TestMulticlassNMS2LoDNoOutput(TestMulticlassNMS2LoDInput):
def set_argument(self):
# Here set 2.0 to test the case there is no outputs.
# In practical use, 0.0 < score_threshold < 1.0
self.score_threshold = 2.0
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册