Unverified · Commit e8f2614d · authored by Guanghua Yu, committed by GitHub

Enhance multiclass_nms op to support LoD for dygraph mode (#28276)

* Enhance multiclass_nms to support LoD for dygraph mode

* fix some errors in multiclass_nms

* rename GetLodFromRoisNum to GetNmsLodFromRoisNum
Parent commit: 842a4e5a
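Note: in dygraph mode tensors carry no LoD, so the new dispensable RoisNum input passes the per-image RoI counts explicitly, and the new NmsRoisNum output reports how many detections were kept per image. The conversion between the two representations (as done by GetNmsLodFromRoisNum and by the NmsRoisNum loop in the diff below) can be sketched with numpy; the variable names here are illustrative only:

import numpy as np

rois_num = np.array([3, 2, 4], dtype='int32')              # RoIs per image (RoisNum input)
lod_offsets = np.concatenate([[0], np.cumsum(rois_num)])   # -> [0, 3, 5, 9], as GetNmsLodFromRoisNum builds
counts_back = np.diff(lod_offsets).astype('int32')         # -> [3, 2, 4], the way NmsRoisNum is derived from batch_starts
assert (counts_back == rois_num).all()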
@@ -21,6 +21,16 @@ namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
inline std::vector<size_t> GetNmsLodFromRoisNum(const Tensor* rois_num) {
  std::vector<size_t> rois_lod;
  auto* rois_num_data = rois_num->data<int>();
  rois_lod.push_back(static_cast<size_t>(0));
  for (int i = 0; i < rois_num->numel(); ++i) {
    rois_lod.push_back(rois_lod.back() +
                       static_cast<size_t>(rois_num_data[i]));
  }
  return rois_lod;
}
class MultiClassNMSOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
@@ -321,6 +331,8 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
    auto* outs = ctx.Output<LoDTensor>("Out");
    bool return_index = ctx.HasOutput("Index") ? true : false;
    auto index = ctx.Output<LoDTensor>("Index");
    bool has_roisnum = ctx.HasInput("RoisNum") ? true : false;
    auto rois_num = ctx.Input<Tensor>("RoisNum");
    auto score_dims = scores->dims();
    auto score_size = score_dims.size();
    auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
@@ -332,7 +344,12 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
    int64_t out_dim = box_dim + 2;
    int num_nmsed_out = 0;
    Tensor boxes_slice, scores_slice;
    int n = 0;
    if (has_roisnum) {
      n = score_size == 3 ? batch_size : rois_num->numel();
    } else {
      n = score_size == 3 ? batch_size : boxes->lod().back().size() - 1;
    }
    for (int i = 0; i < n; ++i) {
      std::map<int, std::vector<int>> indices;
      if (score_size == 3) {
@@ -341,7 +358,12 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
        boxes_slice = boxes->Slice(i, i + 1);
        boxes_slice.Resize({score_dims[2], box_dim});
      } else {
        std::vector<size_t> boxes_lod;
        if (has_roisnum) {
          boxes_lod = GetNmsLodFromRoisNum(rois_num);
        } else {
          boxes_lod = boxes->lod().back();
        }
        if (boxes_lod[i] == boxes_lod[i + 1]) {
          all_indices.push_back(indices);
          batch_starts.push_back(batch_starts.back());
@@ -380,7 +402,12 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
          offset = i * score_dims[2];
        }
      } else {
        std::vector<size_t> boxes_lod;
        if (has_roisnum) {
          boxes_lod = GetNmsLodFromRoisNum(rois_num);
        } else {
          boxes_lod = boxes->lod().back();
        }
        if (boxes_lod[i] == boxes_lod[i + 1]) continue;
        scores_slice = scores->Slice(boxes_lod[i], boxes_lod[i + 1]);
        boxes_slice = boxes->Slice(boxes_lod[i], boxes_lod[i + 1]);
@@ -403,6 +430,15 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
        }
      }
    }
    if (ctx.HasOutput("NmsRoisNum")) {
      auto* nms_rois_num = ctx.Output<Tensor>("NmsRoisNum");
      nms_rois_num->mutable_data<int>({n}, ctx.GetPlace());
      int* num_data = nms_rois_num->data<int>();
      for (int i = 1; i <= n; i++) {
        num_data[i - 1] = batch_starts[i] - batch_starts[i - 1];
      }
      nms_rois_num->Resize({n});
    }
    framework::LoD lod;
    lod.emplace_back(batch_starts);
@@ -535,6 +571,34 @@ class MultiClassNMS2OpMaker : public MultiClassNMSOpMaker {
  }
};
class MultiClassNMS3Op : public MultiClassNMS2Op {
 public:
  MultiClassNMS3Op(const std::string& type,
                   const framework::VariableNameMap& inputs,
                   const framework::VariableNameMap& outputs,
                   const framework::AttributeMap& attrs)
      : MultiClassNMS2Op(type, inputs, outputs, attrs) {}

  void InferShape(framework::InferShapeContext* ctx) const override {
    MultiClassNMS2Op::InferShape(ctx);

    ctx->SetOutputDim("NmsRoisNum", {-1});
  }
};

class MultiClassNMS3OpMaker : public MultiClassNMS2OpMaker {
 public:
  void Make() override {
    MultiClassNMS2OpMaker::Make();
    AddInput("RoisNum",
             "(Tensor) The number of RoIs in shape (B), "
             "B is the number of images")
        .AsDispensable();
    AddOutput("NmsRoisNum", "(Tensor), The number of NMS RoIs in each image")
        .AsDispensable();
  }
};
}  // namespace operators
}  // namespace paddle
@@ -551,3 +615,10 @@ REGISTER_OPERATOR(
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(multiclass_nms2, ops::MultiClassNMSKernel<float>,
                       ops::MultiClassNMSKernel<double>);
REGISTER_OPERATOR(
    multiclass_nms3, ops::MultiClassNMS3Op, ops::MultiClassNMS3OpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(multiclass_nms3, ops::MultiClassNMSKernel<float>,
                       ops::MultiClassNMSKernel<double>);
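For reference, the LoD-style tests added below pair a (M, C, 4) BBoxes tensor and an (M, C) Scores tensor with a RoisNum tensor of shape (B,) holding the per-image RoI counts. A minimal numpy sketch of such inputs (shapes follow the tests; values are random and purely illustrative):

import numpy as np

M, C, B = 1200, 21, 1                                   # RoIs, classes, images (as in the LoD test)
bboxes = np.random.random((M, C, 4)).astype('float32')  # per-RoI, per-class boxes
scores = np.random.random((M, C)).astype('float32')     # per-RoI class scores
rois_num = np.full((B,), M, dtype='int32')              # all M RoIs belong to one image
assert rois_num.sum() == bboxes.shape[0]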
@@ -52,6 +52,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
    {"hierarchical_sigmoid",
     {"X", "W", "Label", "PathTable", "PathCode", "Bias"}},
    {"moving_average_abs_max_scale", {"X", "InAccum", "InState"}},
    {"multiclass_nms3", {"BBoxes", "Scores", "RoisNum"}},
};

// NOTE(zhiqiu): Like op_ins_map.
@@ -78,6 +79,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
    {"distribute_fpn_proposals",
     {"MultiFpnRois", "RestoreIndex", "MultiLevelRoIsNum"}},
    {"moving_average_abs_max_scale", {"OutScale", "OutAccum", "OutState"}},
    {"multiclass_nms3", {"Out", "NmsRoisNum"}},
};

// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
...
@@ -571,6 +571,128 @@ class TestMulticlassNMSError(unittest.TestCase):
        self.assertRaises(TypeError, test_scores_Variable)
class TestMulticlassNMS3Op(TestMulticlassNMS2Op):
    def setUp(self):
        self.set_argument()
        N = 7
        M = 1200
        C = 21
        BOX_SIZE = 4
        background = 0
        nms_threshold = 0.3
        nms_top_k = 400
        keep_top_k = 200
        score_threshold = self.score_threshold

        scores = np.random.random((N * M, C)).astype('float32')
        scores = np.apply_along_axis(softmax, 1, scores)
        scores = np.reshape(scores, (N, M, C))
        scores = np.transpose(scores, (0, 2, 1))

        boxes = np.random.random((N, M, BOX_SIZE)).astype('float32')
        boxes[:, :, 0:2] = boxes[:, :, 0:2] * 0.5
        boxes[:, :, 2:4] = boxes[:, :, 2:4] * 0.5 + 0.5

        det_outs, lod = batched_multiclass_nms(boxes, scores, background,
                                               score_threshold, nms_threshold,
                                               nms_top_k, keep_top_k)
        det_outs = np.array(det_outs)

        nmsed_outs = det_outs[:, :-1].astype('float32') if len(
            det_outs) else det_outs
        index_outs = det_outs[:, -1:].astype('int') if len(
            det_outs) else det_outs
        self.op_type = 'multiclass_nms3'
        self.inputs = {'BBoxes': boxes, 'Scores': scores}
        self.outputs = {
            'Out': (nmsed_outs, [lod]),
            'Index': (index_outs, [lod]),
            'NmsRoisNum': np.array(lod).astype('int32')
        }
        self.attrs = {
            'background_label': 0,
            'nms_threshold': nms_threshold,
            'nms_top_k': nms_top_k,
            'keep_top_k': keep_top_k,
            'score_threshold': score_threshold,
            'nms_eta': 1.0,
            'normalized': True,
        }

    def test_check_output(self):
        self.check_output()
class TestMulticlassNMS3OpNoOutput(TestMulticlassNMS3Op):
    def set_argument(self):
        # Set 2.0 to test the case where there are no outputs.
        # In practical use, 0.0 < score_threshold < 1.0.
        self.score_threshold = 2.0
class TestMulticlassNMS3LoDInput(TestMulticlassNMS2LoDInput):
    def setUp(self):
        self.set_argument()
        M = 1200
        C = 21
        BOX_SIZE = 4
        box_lod = [[1200]]
        background = 0
        nms_threshold = 0.3
        nms_top_k = 400
        keep_top_k = 200
        score_threshold = self.score_threshold
        normalized = False

        scores = np.random.random((M, C)).astype('float32')
        scores = np.apply_along_axis(softmax, 1, scores)

        boxes = np.random.random((M, C, BOX_SIZE)).astype('float32')
        boxes[:, :, 0] = boxes[:, :, 0] * 10
        boxes[:, :, 1] = boxes[:, :, 1] * 10
        boxes[:, :, 2] = boxes[:, :, 2] * 10 + 10
        boxes[:, :, 3] = boxes[:, :, 3] * 10 + 10

        det_outs, lod = lod_multiclass_nms(
            boxes, scores, background, score_threshold, nms_threshold,
            nms_top_k, keep_top_k, box_lod, normalized)

        det_outs = np.array(det_outs)
        nmsed_outs = det_outs[:, :-1].astype('float32') if len(
            det_outs) else det_outs
        self.op_type = 'multiclass_nms3'
        self.inputs = {
            'BBoxes': (boxes, box_lod),
            'Scores': (scores, box_lod),
            'RoisNum': np.array(box_lod).astype('int32')
        }
        self.outputs = {
            'Out': (nmsed_outs, [lod]),
            'NmsRoisNum': np.array(lod).astype('int32')
        }
        self.attrs = {
            'background_label': 0,
            'nms_threshold': nms_threshold,
            'nms_top_k': nms_top_k,
            'keep_top_k': keep_top_k,
            'score_threshold': score_threshold,
            'nms_eta': 1.0,
            'normalized': normalized,
        }

    def test_check_output(self):
        self.check_output()
class TestMulticlassNMS3LoDNoOutput(TestMulticlassNMS3LoDInput):
    def set_argument(self):
        # Set 2.0 to test the case where there are no outputs.
        # In practical use, 0.0 < score_threshold < 1.0.
        self.score_threshold = 2.0
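The NoOutput variants rely on softmax-normalized scores never exceeding 1.0, so a 2.0 threshold filters out every candidate; a standalone numpy check of that assumption (illustrative only, not part of the test file):

import numpy as np

raw = np.random.random((5, 21)).astype('float32')
probs = np.exp(raw) / np.exp(raw).sum(axis=1, keepdims=True)  # softmax over classes
assert probs.max() <= 1.0  # hence score_threshold = 2.0 keeps nothing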
if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()