Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1c6d0646
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1c6d0646
编写于
5月 14, 2019
作者:
J
jerrywgz
提交者:
GitHub
5月 14, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add collect fpn proposals op,test=develop (#16074)
* add collect fpn proposals op,test=develop
上级
60be66e2
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
667 addition
and
4 deletion
+667
-4
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-0
paddle/fluid/operators/detection/CMakeLists.txt
paddle/fluid/operators/detection/CMakeLists.txt
+2
-0
paddle/fluid/operators/detection/bbox_util.h
paddle/fluid/operators/detection/bbox_util.h
+4
-4
paddle/fluid/operators/detection/collect_fpn_proposals_op.cc
paddle/fluid/operators/detection/collect_fpn_proposals_op.cc
+108
-0
paddle/fluid/operators/detection/collect_fpn_proposals_op.cu
paddle/fluid/operators/detection/collect_fpn_proposals_op.cu
+211
-0
paddle/fluid/operators/detection/collect_fpn_proposals_op.h
paddle/fluid/operators/detection/collect_fpn_proposals_op.h
+149
-0
python/paddle/fluid/layers/detection.py
python/paddle/fluid/layers/detection.py
+66
-0
python/paddle/fluid/tests/test_detection.py
python/paddle/fluid/tests/test_detection.py
+26
-0
python/paddle/fluid/tests/unittests/test_collect_fpn_proposals_op.py
...le/fluid/tests/unittests/test_collect_fpn_proposals_op.py
+100
-0
未找到文件。
paddle/fluid/API.spec
浏览文件 @
1c6d0646
...
...
@@ -360,6 +360,7 @@ paddle.fluid.layers.box_clip (ArgSpec(args=['input', 'im_info', 'name'], varargs
paddle.fluid.layers.multiclass_nms (ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)), ('document', 'ca7d1107b6c5d2d6d8221039a220fde0'))
paddle.fluid.layers.distribute_fpn_proposals (ArgSpec(args=['fpn_rois', 'min_level', 'max_level', 'refer_level', 'refer_scale', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '7bb011ec26bace2bc23235aa4a17647d'))
paddle.fluid.layers.box_decoder_and_assign (ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'box_score', 'box_clip', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'dfc953994fd8fef35c49dd9c6eea37a5'))
paddle.fluid.layers.collect_fpn_proposals (ArgSpec(args=['multi_rois', 'multi_scores', 'min_level', 'max_level', 'post_nms_top_n', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '82ffd896ecc3c005ae1cad40854dcace'))
paddle.fluid.layers.accuracy (ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)), ('document', '9808534c12c5e739a10f73ebb0b4eafd'))
paddle.fluid.layers.auc (ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)), ('document', 'e0e95334fce92d16c2d9db6e7caffc47'))
paddle.fluid.layers.exponential_decay (ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,)), ('document', '98a5050bee8522fcea81aa795adaba51'))
...
...
paddle/fluid/operators/detection/CMakeLists.txt
浏览文件 @
1c6d0646
...
...
@@ -39,9 +39,11 @@ detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc bo
if
(
WITH_GPU
)
detection_library
(
generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub
)
detection_library
(
distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc distribute_fpn_proposals_op.cu DEPS memory cub
)
detection_library
(
collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc collect_fpn_proposals_op.cu DEPS memory cub
)
else
()
detection_library
(
generate_proposals_op SRCS generate_proposals_op.cc
)
detection_library
(
distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc
)
detection_library
(
collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc
)
endif
()
detection_library
(
roi_perspective_transform_op SRCS roi_perspective_transform_op.cc roi_perspective_transform_op.cu
)
...
...
paddle/fluid/operators/detection/bbox_util.h
浏览文件 @
1c6d0646
...
...
@@ -22,10 +22,10 @@ namespace paddle {
namespace
operators
{
struct
RangeInitFunctor
{
int
start
_
;
int
delta
_
;
int
*
out
_
;
HOSTDEVICE
void
operator
()(
size_t
i
)
{
out
_
[
i
]
=
start_
+
i
*
delta_
;
}
int
start
;
int
delta
;
int
*
out
;
HOSTDEVICE
void
operator
()(
size_t
i
)
{
out
[
i
]
=
start
+
i
*
delta
;
}
};
template
<
typename
T
>
...
...
paddle/fluid/operators/detection/collect_fpn_proposals_op.cc
0 → 100644
浏览文件 @
1c6d0646
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.*/
#include "paddle/fluid/operators/detection/collect_fpn_proposals_op.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
class
CollectFpnProposalsOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
context
)
const
override
{
PADDLE_ENFORCE
(
context
->
HasInputs
(
"MultiLevelRois"
),
"Inputs(MultiLevelRois) shouldn't be null"
);
PADDLE_ENFORCE
(
context
->
HasInputs
(
"MultiLevelScores"
),
"Inputs(MultiLevelScores) shouldn't be null"
);
PADDLE_ENFORCE
(
context
->
HasOutput
(
"FpnRois"
),
"Outputs(MultiFpnRois) of DistributeOp should not be null"
);
auto
roi_dims
=
context
->
GetInputsDim
(
"MultiLevelRois"
);
auto
score_dims
=
context
->
GetInputsDim
(
"MultiLevelScores"
);
auto
post_nms_topN
=
context
->
Attrs
().
Get
<
int
>
(
"post_nms_topN"
);
std
::
vector
<
int64_t
>
out_dims
;
for
(
auto
&
roi_dim
:
roi_dims
)
{
PADDLE_ENFORCE_EQ
(
roi_dim
[
1
],
4
,
"Second dimension of Input(MultiLevelRois) must be 4"
);
}
for
(
auto
&
score_dim
:
score_dims
)
{
PADDLE_ENFORCE_EQ
(
score_dim
[
1
],
1
,
"Second dimension of Input(MultiLevelScores) must be 1"
);
}
context
->
SetOutputDim
(
"FpnRois"
,
{
post_nms_topN
,
4
});
if
(
!
context
->
IsRuntime
())
{
// Runtime LoD infershape will be computed
// in Kernel.
context
->
ShareLoD
(
"MultiLevelRois"
,
"FpnRois"
);
}
if
(
context
->
IsRuntime
())
{
std
::
vector
<
framework
::
InferShapeVarPtr
>
roi_inputs
=
context
->
GetInputVarPtrs
(
"MultiLevelRois"
);
std
::
vector
<
framework
::
InferShapeVarPtr
>
score_inputs
=
context
->
GetInputVarPtrs
(
"MultiLevelScores"
);
for
(
size_t
i
=
0
;
i
<
roi_inputs
.
size
();
++
i
)
{
framework
::
Variable
*
roi_var
=
boost
::
get
<
framework
::
Variable
*>
(
roi_inputs
[
i
]);
framework
::
Variable
*
score_var
=
boost
::
get
<
framework
::
Variable
*>
(
score_inputs
[
i
]);
auto
&
roi_lod
=
roi_var
->
Get
<
LoDTensor
>
().
lod
();
auto
&
score_lod
=
score_var
->
Get
<
LoDTensor
>
().
lod
();
PADDLE_ENFORCE_EQ
(
roi_lod
,
score_lod
,
"Inputs(MultiLevelRois) and Inputs(MultiLevelScores) "
"should have same lod."
);
}
}
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
data_type
=
framework
::
GetDataTypeOfVar
(
ctx
.
MultiInputVar
(
"MultiLevelRois"
)[
0
]);
return
framework
::
OpKernelType
(
data_type
,
ctx
.
GetPlace
());
}
};
class
CollectFpnProposalsOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"MultiLevelRois"
,
"(LoDTensor) Multiple roi LoDTensors from each level in shape "
"(N, 4), N is the number of RoIs"
)
.
AsDuplicable
();
AddInput
(
"MultiLevelScores"
,
"(LoDTensor) Multiple score LoDTensors from each level in shape"
" (N, 1), N is the number of RoIs."
)
.
AsDuplicable
();
AddOutput
(
"FpnRois"
,
"(LoDTensor) All selected RoIs with highest scores"
);
AddAttr
<
int
>
(
"post_nms_topN"
,
"Select post_nms_topN RoIs from"
" all images and all fpn layers"
);
AddComment
(
R"DOC(
This operator concats all proposals from different images
and different FPN levels. Then sort all of those proposals
by objectness confidence. Select the post_nms_topN RoIs in
total. Finally, re-sort the RoIs in the order of batch index.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
collect_fpn_proposals
,
ops
::
CollectFpnProposalsOp
,
ops
::
CollectFpnProposalsOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
);
REGISTER_OP_CPU_KERNEL
(
collect_fpn_proposals
,
ops
::
CollectFpnProposalsOpKernel
<
float
>
,
ops
::
CollectFpnProposalsOpKernel
<
double
>
);
paddle/fluid/operators/detection/collect_fpn_proposals_op.cu
0 → 100644
浏览文件 @
1c6d0646
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <paddle/fluid/memory/allocation/allocator.h>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/operators/detection/collect_fpn_proposals_op.h"
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/for_range.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
static
constexpr
int
kNumCUDAThreads
=
64
;
static
constexpr
int
kNumMaxinumNumBlocks
=
4096
;
const
int
kBBoxSize
=
4
;
static
inline
int
NumBlocks
(
const
int
N
)
{
return
std
::
min
((
N
+
kNumCUDAThreads
-
1
)
/
kNumCUDAThreads
,
kNumMaxinumNumBlocks
);
}
static
__global__
void
GetLengthLoD
(
const
int
nthreads
,
const
int
*
batch_ids
,
int
*
length_lod
)
{
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
(
nthreads
);
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
platform
::
CudaAtomicAdd
(
length_lod
+
batch_ids
[
i
],
1
);
}
}
template
<
typename
DeviceContext
,
typename
T
>
class
GPUCollectFpnProposalsOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
roi_ins
=
ctx
.
MultiInput
<
LoDTensor
>
(
"MultiLevelRois"
);
const
auto
score_ins
=
ctx
.
MultiInput
<
LoDTensor
>
(
"MultiLevelScores"
);
auto
fpn_rois
=
ctx
.
Output
<
LoDTensor
>
(
"FpnRois"
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
const
int
post_nms_topN
=
ctx
.
Attr
<
int
>
(
"post_nms_topN"
);
// concat inputs along axis = 0
int
roi_offset
=
0
;
int
score_offset
=
0
;
int
total_roi_num
=
0
;
for
(
size_t
i
=
0
;
i
<
roi_ins
.
size
();
++
i
)
{
total_roi_num
+=
roi_ins
[
i
]
->
dims
()[
0
];
}
int
real_post_num
=
min
(
post_nms_topN
,
total_roi_num
);
fpn_rois
->
mutable_data
<
T
>
({
real_post_num
,
kBBoxSize
},
dev_ctx
.
GetPlace
());
Tensor
concat_rois
;
Tensor
concat_scores
;
T
*
concat_rois_data
=
concat_rois
.
mutable_data
<
T
>
(
{
total_roi_num
,
kBBoxSize
},
dev_ctx
.
GetPlace
());
T
*
concat_scores_data
=
concat_scores
.
mutable_data
<
T
>
({
total_roi_num
,
1
},
dev_ctx
.
GetPlace
());
Tensor
roi_batch_id_list
;
roi_batch_id_list
.
Resize
({
total_roi_num
});
int
*
roi_batch_id_data
=
roi_batch_id_list
.
mutable_data
<
int
>
(
platform
::
CPUPlace
());
int
index
=
0
;
int
lod_size
;
auto
place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
dev_ctx
.
GetPlace
());
for
(
size_t
i
=
0
;
i
<
roi_ins
.
size
();
++
i
)
{
auto
roi_in
=
roi_ins
[
i
];
auto
score_in
=
score_ins
[
i
];
auto
roi_lod
=
roi_in
->
lod
().
back
();
lod_size
=
roi_lod
.
size
()
-
1
;
for
(
size_t
n
=
0
;
n
<
lod_size
;
++
n
)
{
for
(
size_t
j
=
roi_lod
[
n
];
j
<
roi_lod
[
n
+
1
];
++
j
)
{
roi_batch_id_data
[
index
++
]
=
n
;
}
}
memory
::
Copy
(
place
,
concat_rois_data
+
roi_offset
,
place
,
roi_in
->
data
<
T
>
(),
roi_in
->
numel
()
*
sizeof
(
T
),
dev_ctx
.
stream
());
memory
::
Copy
(
place
,
concat_scores_data
+
score_offset
,
place
,
score_in
->
data
<
T
>
(),
score_in
->
numel
()
*
sizeof
(
T
),
dev_ctx
.
stream
());
roi_offset
+=
roi_in
->
numel
();
score_offset
+=
score_in
->
numel
();
}
// copy batch id list to GPU
Tensor
roi_batch_id_list_gpu
;
framework
::
TensorCopy
(
roi_batch_id_list
,
dev_ctx
.
GetPlace
(),
&
roi_batch_id_list_gpu
);
Tensor
index_in_t
;
int
*
idx_in
=
index_in_t
.
mutable_data
<
int
>
({
total_roi_num
},
dev_ctx
.
GetPlace
());
platform
::
ForRange
<
platform
::
CUDADeviceContext
>
for_range_total
(
dev_ctx
,
total_roi_num
);
for_range_total
(
RangeInitFunctor
{
0
,
1
,
idx_in
});
Tensor
keys_out_t
;
T
*
keys_out
=
keys_out_t
.
mutable_data
<
T
>
({
total_roi_num
},
dev_ctx
.
GetPlace
());
Tensor
index_out_t
;
int
*
idx_out
=
index_out_t
.
mutable_data
<
int
>
({
total_roi_num
},
dev_ctx
.
GetPlace
());
// Determine temporary device storage requirements
size_t
temp_storage_bytes
=
0
;
cub
::
DeviceRadixSort
::
SortPairsDescending
<
T
,
int
>
(
nullptr
,
temp_storage_bytes
,
concat_scores
.
data
<
T
>
(),
keys_out
,
idx_in
,
idx_out
,
total_roi_num
);
// Allocate temporary storage
auto
d_temp_storage
=
memory
::
Alloc
(
place
,
temp_storage_bytes
,
memory
::
Allocator
::
kScratchpad
);
// Run sorting operation
// sort score to get corresponding index
cub
::
DeviceRadixSort
::
SortPairsDescending
<
T
,
int
>
(
d_temp_storage
->
ptr
(),
temp_storage_bytes
,
concat_scores
.
data
<
T
>
(),
keys_out
,
idx_in
,
idx_out
,
total_roi_num
);
index_out_t
.
Resize
({
real_post_num
});
Tensor
sorted_rois
;
sorted_rois
.
mutable_data
<
T
>
({
real_post_num
,
kBBoxSize
},
dev_ctx
.
GetPlace
());
Tensor
sorted_batch_id
;
sorted_batch_id
.
mutable_data
<
int
>
({
real_post_num
},
dev_ctx
.
GetPlace
());
GPUGather
<
T
>
(
dev_ctx
,
concat_rois
,
index_out_t
,
&
sorted_rois
);
GPUGather
<
int
>
(
dev_ctx
,
roi_batch_id_list_gpu
,
index_out_t
,
&
sorted_batch_id
);
Tensor
batch_index_t
;
int
*
batch_idx_in
=
batch_index_t
.
mutable_data
<
int
>
({
real_post_num
},
dev_ctx
.
GetPlace
());
platform
::
ForRange
<
platform
::
CUDADeviceContext
>
for_range_post
(
dev_ctx
,
real_post_num
);
for_range_post
(
RangeInitFunctor
{
0
,
1
,
batch_idx_in
});
Tensor
out_id_t
;
int
*
out_id_data
=
out_id_t
.
mutable_data
<
int
>
({
real_post_num
},
dev_ctx
.
GetPlace
());
// Determine temporary device storage requirements
temp_storage_bytes
=
0
;
cub
::
DeviceRadixSort
::
SortPairs
<
int
,
int
>
(
nullptr
,
temp_storage_bytes
,
sorted_batch_id
.
data
<
int
>
(),
out_id_data
,
batch_idx_in
,
index_out_t
.
data
<
int
>
(),
real_post_num
);
// Allocate temporary storage
d_temp_storage
=
memory
::
Alloc
(
place
,
temp_storage_bytes
,
memory
::
Allocator
::
kScratchpad
);
// Run sorting operation
// sort batch_id to get corresponding index
cub
::
DeviceRadixSort
::
SortPairs
<
int
,
int
>
(
d_temp_storage
->
ptr
(),
temp_storage_bytes
,
sorted_batch_id
.
data
<
int
>
(),
out_id_data
,
batch_idx_in
,
index_out_t
.
data
<
int
>
(),
real_post_num
);
GPUGather
<
T
>
(
dev_ctx
,
sorted_rois
,
index_out_t
,
fpn_rois
);
Tensor
length_lod
;
int
*
length_lod_data
=
length_lod
.
mutable_data
<
int
>
({
lod_size
},
dev_ctx
.
GetPlace
());
math
::
SetConstant
<
platform
::
CUDADeviceContext
,
int
>
set_zero
;
set_zero
(
dev_ctx
,
&
length_lod
,
static_cast
<
int
>
(
0
));
int
blocks
=
NumBlocks
(
real_post_num
);
int
threads
=
kNumCUDAThreads
;
// get length-based lod by batch ids
GetLengthLoD
<<<
blocks
,
threads
>>>
(
real_post_num
,
out_id_data
,
length_lod_data
);
std
::
vector
<
int
>
length_lod_cpu
(
lod_size
);
memory
::
Copy
(
platform
::
CPUPlace
(),
length_lod_cpu
.
data
(),
place
,
length_lod_data
,
sizeof
(
int
)
*
lod_size
,
dev_ctx
.
stream
());
dev_ctx
.
Wait
();
std
::
vector
<
size_t
>
offset
(
1
,
0
);
for
(
int
i
=
0
;
i
<
lod_size
;
++
i
)
{
offset
.
emplace_back
(
offset
.
back
()
+
length_lod_cpu
[
i
]);
}
framework
::
LoD
lod
;
lod
.
emplace_back
(
offset
);
fpn_rois
->
set_lod
(
lod
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
collect_fpn_proposals
,
ops
::
GPUCollectFpnProposalsOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
GPUCollectFpnProposalsOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/detection/collect_fpn_proposals_op.h
0 → 100644
浏览文件 @
1c6d0646
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.*/
#pragma once
#include <algorithm>
#include <cmath>
#include <cstring>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
const
int
kBoxDim
=
4
;
template
<
typename
T
>
struct
ScoreWithID
{
T
score
;
int
batch_id
;
int
index
;
int
level
;
ScoreWithID
()
{
batch_id
=
-
1
;
index
=
-
1
;
level
=
-
1
;
}
ScoreWithID
(
T
score_
,
int
batch_id_
,
int
index_
,
int
level_
)
{
score
=
score_
;
batch_id
=
batch_id_
;
index
=
index_
;
level
=
level_
;
}
};
template
<
typename
T
>
static
inline
bool
CompareByScore
(
ScoreWithID
<
T
>
a
,
ScoreWithID
<
T
>
b
)
{
return
a
.
score
>=
b
.
score
;
}
template
<
typename
T
>
static
inline
bool
CompareByBatchid
(
ScoreWithID
<
T
>
a
,
ScoreWithID
<
T
>
b
)
{
return
a
.
batch_id
<
b
.
batch_id
;
}
template
<
typename
T
>
class
CollectFpnProposalsOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
multi_layer_rois
=
context
.
MultiInput
<
paddle
::
framework
::
LoDTensor
>
(
"MultiLevelRois"
);
auto
multi_layer_scores
=
context
.
MultiInput
<
paddle
::
framework
::
LoDTensor
>
(
"MultiLevelScores"
);
auto
*
fpn_rois
=
context
.
Output
<
paddle
::
framework
::
LoDTensor
>
(
"FpnRois"
);
int
post_nms_topN
=
context
.
Attr
<
int
>
(
"post_nms_topN"
);
PADDLE_ENFORCE_GE
(
post_nms_topN
,
0UL
,
"The parameter post_nms_topN must be a positive integer"
);
// assert that the length of Rois and scores are same
PADDLE_ENFORCE
(
multi_layer_rois
.
size
()
==
multi_layer_scores
.
size
(),
"DistributeFpnProposalsOp need 1 level of LoD"
);
// Check if the lod information of two LoDTensor is same
const
int
num_fpn_level
=
multi_layer_rois
.
size
();
std
::
vector
<
int
>
integral_of_all_rois
(
num_fpn_level
+
1
,
0
);
for
(
int
i
=
0
;
i
<
num_fpn_level
;
++
i
)
{
auto
cur_rois_lod
=
multi_layer_rois
[
i
]
->
lod
().
back
();
integral_of_all_rois
[
i
+
1
]
=
integral_of_all_rois
[
i
]
+
cur_rois_lod
[
cur_rois_lod
.
size
()
-
1
];
}
// concatenate all fpn rois scores into a list
// create a vector to store all scores
std
::
vector
<
ScoreWithID
<
T
>>
scores_of_all_rois
(
integral_of_all_rois
[
num_fpn_level
],
ScoreWithID
<
T
>
());
for
(
int
i
=
0
;
i
<
num_fpn_level
;
++
i
)
{
const
T
*
cur_level_scores
=
multi_layer_scores
[
i
]
->
data
<
T
>
();
int
cur_level_num
=
integral_of_all_rois
[
i
+
1
]
-
integral_of_all_rois
[
i
];
auto
cur_scores_lod
=
multi_layer_scores
[
i
]
->
lod
().
back
();
int
cur_batch_id
=
0
;
for
(
int
j
=
0
;
j
<
cur_level_num
;
++
j
)
{
if
(
j
>=
cur_scores_lod
[
cur_batch_id
+
1
])
{
cur_batch_id
++
;
}
int
cur_index
=
j
+
integral_of_all_rois
[
i
];
scores_of_all_rois
[
cur_index
].
score
=
cur_level_scores
[
j
];
scores_of_all_rois
[
cur_index
].
index
=
j
;
scores_of_all_rois
[
cur_index
].
level
=
i
;
scores_of_all_rois
[
cur_index
].
batch_id
=
cur_batch_id
;
}
}
// keep top post_nms_topN rois
// sort the rois by the score
if
(
post_nms_topN
>
integral_of_all_rois
[
num_fpn_level
])
{
post_nms_topN
=
integral_of_all_rois
[
num_fpn_level
];
}
std
::
stable_sort
(
scores_of_all_rois
.
begin
(),
scores_of_all_rois
.
end
(),
CompareByScore
<
T
>
);
scores_of_all_rois
.
resize
(
post_nms_topN
);
// sort by batch id
std
::
stable_sort
(
scores_of_all_rois
.
begin
(),
scores_of_all_rois
.
end
(),
CompareByBatchid
<
T
>
);
// create a pointer array
std
::
vector
<
const
T
*>
multi_fpn_rois_data
(
num_fpn_level
);
for
(
int
i
=
0
;
i
<
num_fpn_level
;
++
i
)
{
multi_fpn_rois_data
[
i
]
=
multi_layer_rois
[
i
]
->
data
<
T
>
();
}
// initialize the outputs
fpn_rois
->
mutable_data
<
T
>
({
post_nms_topN
,
kBoxDim
},
context
.
GetPlace
());
T
*
fpn_rois_data
=
fpn_rois
->
data
<
T
>
();
std
::
vector
<
size_t
>
lod0
(
1
,
0
);
int
cur_batch_id
=
0
;
for
(
int
i
=
0
;
i
<
post_nms_topN
;
++
i
)
{
int
cur_fpn_level
=
scores_of_all_rois
[
i
].
level
;
int
cur_level_index
=
scores_of_all_rois
[
i
].
index
;
memcpy
(
fpn_rois_data
,
multi_fpn_rois_data
[
cur_fpn_level
]
+
cur_level_index
*
kBoxDim
,
kBoxDim
*
sizeof
(
T
));
fpn_rois_data
+=
kBoxDim
;
if
(
scores_of_all_rois
[
i
].
batch_id
!=
cur_batch_id
)
{
cur_batch_id
=
scores_of_all_rois
[
i
].
batch_id
;
lod0
.
emplace_back
(
i
);
}
}
lod0
.
emplace_back
(
post_nms_topN
);
framework
::
LoD
lod
;
lod
.
emplace_back
(
lod0
);
fpn_rois
->
set_lod
(
lod
);
}
};
}
// namespace operators
}
// namespace paddle
python/paddle/fluid/layers/detection.py
浏览文件 @
1c6d0646
...
...
@@ -54,6 +54,7 @@ __all__ = [
'multiclass_nms'
,
'distribute_fpn_proposals'
,
'box_decoder_and_assign'
,
'collect_fpn_proposals'
,
]
...
...
@@ -2512,3 +2513,68 @@ def box_decoder_and_assign(prior_box,
"OutputAssignBox"
:
output_assign_box
})
return
decoded_box
,
output_assign_box
def
collect_fpn_proposals
(
multi_rois
,
multi_scores
,
min_level
,
max_level
,
post_nms_top_n
,
name
=
None
):
"""
Concat multi-level RoIs (Region of Interest) and select N RoIs
with respect to multi_scores. This operation performs the following steps:
1. Choose num_level RoIs and scores as input: num_level = max_level - min_level
2. Concat multi-level RoIs and scores
3. Sort scores and select post_nms_top_n scores
4. Gather RoIs by selected indices from scores
5. Re-sort RoIs by corresponding batch_id
Args:
multi_ros(list): List of RoIs to collect
multi_scores(list): List of scores
min_level(int): The lowest level of FPN layer to collect
max_level(int): The highest level of FPN layer to collect
post_nms_top_n(int): The number of selected RoIs
name(str|None): A name for this layer(optional)
Returns:
Variable: Output variable of selected RoIs.
Examples:
.. code-block:: python
multi_rois = []
multi_scores = []
for i in range(4):
multi_rois.append(fluid.layers.data(
name='roi_'+str(i), shape=[4], dtype='float32', lod_level=1))
for i in range(4):
multi_scores.append(fluid.layers.data(
name='score_'+str(i), shape=[1], dtype='float32', lod_level=1))
fpn_rois = fluid.layers.collect_fpn_proposals(
multi_rois=multi_rois,
multi_scores=multi_scores,
min_level=2,
max_level=5,
post_nms_top_n=2000)
"""
helper
=
LayerHelper
(
'collect_fpn_proposals'
,
**
locals
())
dtype
=
helper
.
input_dtype
(
'multi_rois'
)
num_lvl
=
max_level
-
min_level
+
1
input_rois
=
multi_rois
[:
num_lvl
]
input_scores
=
multi_scores
[:
num_lvl
]
output_rois
=
helper
.
create_variable_for_type_inference
(
dtype
)
output_rois
.
stop_gradient
=
True
helper
.
append_op
(
type
=
'collect_fpn_proposals'
,
inputs
=
{
'MultiLevelRois'
:
input_rois
,
'MultiLevelScores'
:
input_scores
},
outputs
=
{
'FpnRois'
:
output_rois
},
attrs
=
{
'post_nms_topN'
:
post_nms_top_n
})
return
output_rois
python/paddle/fluid/tests/test_detection.py
浏览文件 @
1c6d0646
...
...
@@ -522,6 +522,32 @@ class TestMulticlassNMS(unittest.TestCase):
self
.
assertIsNotNone
(
output
)
class
TestCollectFpnPropsals
(
unittest
.
TestCase
):
def
test_collect_fpn_proposals
(
self
):
program
=
Program
()
with
program_guard
(
program
):
multi_bboxes
=
[]
multi_scores
=
[]
for
i
in
range
(
4
):
bboxes
=
layers
.
data
(
name
=
'rois'
+
str
(
i
),
shape
=
[
10
,
4
],
dtype
=
'float32'
,
lod_level
=
1
,
append_batch_size
=
False
)
scores
=
layers
.
data
(
name
=
'scores'
+
str
(
i
),
shape
=
[
10
,
1
],
dtype
=
'float32'
,
lod_level
=
1
,
append_batch_size
=
False
)
multi_bboxes
.
append
(
bboxes
)
multi_scores
.
append
(
scores
)
fpn_rois
=
layers
.
collect_fpn_proposals
(
multi_bboxes
,
multi_scores
,
2
,
5
,
10
)
self
.
assertIsNotNone
(
fpn_rois
)
class
TestDistributeFpnProposals
(
unittest
.
TestCase
):
def
test_distribute_fpn_proposals
(
self
):
program
=
Program
()
...
...
python/paddle/fluid/tests/unittests/test_collect_fpn_proposals_op.py
0 → 100644
浏览文件 @
1c6d0646
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
math
import
sys
from
op_test
import
OpTest
class
TestCollectFPNProposalstOp
(
OpTest
):
def
set_data
(
self
):
self
.
init_test_case
()
self
.
make_rois
()
self
.
scores_input
=
[(
'y%d'
%
i
,
(
self
.
scores
[
i
].
reshape
(
-
1
,
1
),
self
.
rois_lod
[
i
]))
for
i
in
range
(
self
.
num_level
)]
self
.
rois
,
self
.
lod
=
self
.
calc_rois_collect
()
inputs_x
=
[(
'x%d'
%
i
,
(
self
.
roi_inputs
[
i
][:,
1
:],
self
.
rois_lod
[
i
]))
for
i
in
range
(
self
.
num_level
)]
self
.
inputs
=
{
'MultiLevelRois'
:
inputs_x
,
"MultiLevelScores"
:
self
.
scores_input
}
self
.
attrs
=
{
'post_nms_topN'
:
self
.
post_nms_top_n
,
}
self
.
outputs
=
{
'FpnRois'
:
(
self
.
rois
,
[
self
.
lod
])}
def
init_test_case
(
self
):
self
.
post_nms_top_n
=
20
self
.
images_shape
=
[
100
,
100
]
def
resort_roi_by_batch_id
(
self
,
rois
):
batch_id_list
=
rois
[:,
0
]
batch_size
=
int
(
batch_id_list
.
max
())
sorted_rois
=
[]
new_lod
=
[]
for
batch_id
in
range
(
batch_size
+
1
):
sub_ind
=
np
.
where
(
batch_id_list
==
batch_id
)[
0
]
sub_rois
=
rois
[
sub_ind
,
1
:]
sorted_rois
.
append
(
sub_rois
)
new_lod
.
append
(
len
(
sub_rois
))
new_rois
=
np
.
concatenate
(
sorted_rois
)
return
new_rois
,
new_lod
def
calc_rois_collect
(
self
):
roi_inputs
=
np
.
concatenate
(
self
.
roi_inputs
)
scores
=
np
.
concatenate
(
self
.
scores
)
inds
=
np
.
argsort
(
-
scores
)[:
self
.
post_nms_top_n
]
rois
=
roi_inputs
[
inds
,
:]
new_rois
,
new_lod
=
self
.
resort_roi_by_batch_id
(
rois
)
return
new_rois
,
new_lod
def
make_rois
(
self
):
self
.
num_level
=
4
self
.
roi_inputs
=
[]
self
.
scores
=
[]
self
.
rois_lod
=
[[[
20
,
10
]],
[[
30
,
20
]],
[[
20
,
30
]],
[[
10
,
10
]]]
for
lvl
in
range
(
self
.
num_level
):
rois
=
[]
scores_pb
=
[]
lod
=
self
.
rois_lod
[
lvl
][
0
]
bno
=
0
for
roi_num
in
lod
:
for
i
in
range
(
roi_num
):
xywh
=
np
.
random
.
rand
(
4
)
xy1
=
xywh
[
0
:
2
]
*
20
wh
=
xywh
[
2
:
4
]
*
(
self
.
images_shape
-
xy1
)
xy2
=
xy1
+
wh
roi
=
[
bno
,
xy1
[
0
],
xy1
[
1
],
xy2
[
0
],
xy2
[
1
]]
rois
.
append
(
roi
)
bno
+=
1
scores_pb
.
extend
(
list
(
np
.
random
.
uniform
(
0.0
,
1.0
,
roi_num
)))
rois
=
np
.
array
(
rois
).
astype
(
"float32"
)
self
.
roi_inputs
.
append
(
rois
)
scores_pb
=
np
.
array
(
scores_pb
).
astype
(
"float32"
)
self
.
scores
.
append
(
scores_pb
)
def
setUp
(
self
):
self
.
op_type
=
"collect_fpn_proposals"
self
.
set_data
()
def
test_check_output
(
self
):
self
.
check_output
()
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录