Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
c8610739
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c8610739
编写于
3月 08, 2019
作者:
C
ceci3
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into npair_loss0
上级
23a9035b
5bde1202
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
671 addition
and
0 deletion
+671
-0
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-0
paddle/fluid/operators/detection/CMakeLists.txt
paddle/fluid/operators/detection/CMakeLists.txt
+2
-0
paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc
.../fluid/operators/detection/distribute_fpn_proposals_op.cc
+93
-0
paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu
.../fluid/operators/detection/distribute_fpn_proposals_op.cu
+221
-0
paddle/fluid/operators/detection/distribute_fpn_proposals_op.h
...e/fluid/operators/detection/distribute_fpn_proposals_op.h
+147
-0
python/paddle/fluid/layers/detection.py
python/paddle/fluid/layers/detection.py
+74
-0
python/paddle/fluid/tests/test_detection.py
python/paddle/fluid/tests/test_detection.py
+16
-0
python/paddle/fluid/tests/unittests/test_distribute_fpn_proposals_op.py
...fluid/tests/unittests/test_distribute_fpn_proposals_op.py
+117
-0
未找到文件。
paddle/fluid/API.spec
浏览文件 @
c8610739
...
...
@@ -330,6 +330,7 @@ paddle.fluid.layers.polygon_box_transform (ArgSpec(args=['input', 'name'], varar
paddle.fluid.layers.yolov3_loss (ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample_ratio', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '991e934c3e09abf0edec7c9c978b4691'))
paddle.fluid.layers.box_clip (ArgSpec(args=['input', 'im_info', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '397e9e02b451d99c56e20f268fa03f2e'))
paddle.fluid.layers.multiclass_nms (ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)), ('document', 'ca7d1107b6c5d2d6d8221039a220fde0'))
paddle.fluid.layers.distribute_fpn_proposals (ArgSpec(args=['fpn_rois', 'min_level', 'max_level', 'refer_level', 'refer_scale', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '7bb011ec26bace2bc23235aa4a17647d'))
paddle.fluid.layers.box_decoder_and_assign (ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'box_score', 'box_clip', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '005a5ae47d6c8fff721931d69d072b9f'))
paddle.fluid.layers.accuracy (ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)), ('document', '9808534c12c5e739a10f73ebb0b4eafd'))
paddle.fluid.layers.auc (ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)), ('document', 'e0e95334fce92d16c2d9db6e7caffc47'))
...
...
paddle/fluid/operators/detection/CMakeLists.txt
浏览文件 @
c8610739
...
...
@@ -37,8 +37,10 @@ detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc bo
if
(
WITH_GPU
)
detection_library
(
generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub
)
detection_library
(
distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc distribute_fpn_proposals_op.cu DEPS memory cub
)
else
()
detection_library
(
generate_proposals_op SRCS generate_proposals_op.cc
)
detection_library
(
distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc
)
endif
()
detection_library
(
roi_perspective_transform_op SRCS roi_perspective_transform_op.cc roi_perspective_transform_op.cu
)
...
...
paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc
0 → 100644
浏览文件 @
c8610739
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/distribute_fpn_proposals_op.h"
namespace
paddle
{
namespace
operators
{
class
DistributeFpnProposalsOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"FpnRois"
),
"Input(FpnRois) shouldn't be null"
);
PADDLE_ENFORCE_GE
(
ctx
->
Outputs
(
"MultiFpnRois"
).
size
(),
1UL
,
"Outputs(MultiFpnRois) of DistributeOp should not be empty"
);
size_t
min_level
=
static_cast
<
size_t
>
(
ctx
->
Attrs
().
Get
<
int
>
(
"min_level"
));
size_t
max_level
=
static_cast
<
size_t
>
(
ctx
->
Attrs
().
Get
<
int
>
(
"max_level"
));
PADDLE_ENFORCE_GE
(
max_level
,
min_level
,
"max_level must not lower than min_level"
);
// Set the output shape
size_t
num_out_rois
=
max_level
-
min_level
+
1
;
std
::
vector
<
framework
::
DDim
>
outs_dims
;
outs_dims
.
reserve
(
num_out_rois
);
for
(
size_t
i
=
0
;
i
<
num_out_rois
;
++
i
)
{
framework
::
DDim
out_dim
=
{
-
1
,
4
};
outs_dims
.
push_back
(
out_dim
);
}
ctx
->
SetOutputsDim
(
"MultiFpnRois"
,
outs_dims
);
ctx
->
SetOutputDim
(
"RestoreIndex"
,
{
1
,
-
1
});
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
data_type
=
framework
::
GetDataTypeOfVar
(
ctx
.
InputVar
(
"FpnRois"
));
return
framework
::
OpKernelType
(
data_type
,
platform
::
CPUPlace
());
}
};
class
DistributeFpnProposalsOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"FpnRois"
,
"(LoDTensor) The rois at all levels in shape (-1, 4)"
);
AddOutput
(
"MultiFpnRois"
,
"(LoDTensor) Output with distribute operator"
)
.
AsDuplicable
();
AddOutput
(
"RestoreIndex"
,
"(Tensor) An array of positive number which is "
"used to restore the order of FpnRois"
);
AddAttr
<
int
>
(
"min_level"
,
"The lowest level of FPN layer where the"
" proposals come from"
);
AddAttr
<
int
>
(
"max_level"
,
"The highest level of FPN layer where the"
" proposals come from"
);
AddAttr
<
int
>
(
"refer_level"
,
"The referring level of FPN layer with"
" specified scale"
);
AddAttr
<
int
>
(
"refer_scale"
,
"The referring scale of FPN layer with"
" specified level"
);
AddComment
(
R"DOC(
This operator distribute all proposals into different fpn level,
with respect to scale of the proposals, the referring scale and
the referring level. Besides, to restore the order of proposals,
we return an array which indicate the original index of rois in
current proposals.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
distribute_fpn_proposals
,
ops
::
DistributeFpnProposalsOp
,
ops
::
DistributeFpnProposalsOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
);
REGISTER_OP_CPU_KERNEL
(
distribute_fpn_proposals
,
ops
::
DistributeFpnProposalsOpKernel
<
float
>
,
ops
::
DistributeFpnProposalsOpKernel
<
double
>
);
paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu
0 → 100644
浏览文件 @
c8610739
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <paddle/fluid/memory/allocation/allocator.h>
#include "cub/cub.cuh"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/detection/distribute_fpn_proposals_op.h"
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/for_range.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
static
constexpr
int
kNumCUDAThreads
=
512
;
static
constexpr
int
kNumMaxinumNumBlocks
=
4096
;
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
int
const
BBoxSize
=
4
;
struct
RangeInitFunctor
{
int
start_
;
int
delta_
;
int
*
out_
;
__device__
void
operator
()(
size_t
i
)
{
out_
[
i
]
=
start_
+
i
*
delta_
;
}
};
static
inline
int
NumBlocks
(
const
int
N
)
{
return
std
::
min
((
N
+
kNumCUDAThreads
-
1
)
/
kNumCUDAThreads
,
kNumMaxinumNumBlocks
);
}
static
inline
void
TransLoD
(
const
int
*
length_lod
,
const
int
lod_size
,
int
*
offset_lod
)
{
int
offset
=
0
;
for
(
int
i
=
0
;
i
<
lod_size
;
++
i
)
{
offset_lod
[
i
]
=
offset
;
offset
+=
length_lod
[
i
];
}
}
template
<
typename
T
>
static
__device__
inline
T
RoIArea
(
const
T
*
box
,
bool
normalized
)
{
if
(
box
[
2
]
<
box
[
0
]
||
box
[
3
]
<
box
[
1
])
{
// If coordinate values are is invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return
static_cast
<
T
>
(
0.
);
}
else
{
const
T
w
=
box
[
2
]
-
box
[
0
];
const
T
h
=
box
[
3
]
-
box
[
1
];
if
(
normalized
)
{
return
w
*
h
;
}
else
{
// If coordinate values are not within range [0, 1].
return
(
w
+
1
)
*
(
h
+
1
);
}
}
}
template
<
class
T
>
static
__global__
void
GPUDistFpnProposalsHelper
(
const
int
nthreads
,
const
T
*
rois
,
const
int
lod_size
,
const
int
refer_level
,
const
int
refer_scale
,
const
int
max_level
,
const
int
min_level
,
int
*
roi_batch_id_data
,
int
*
sub_lod_list
,
int
*
target_lvls
)
{
CUDA_1D_KERNEL_LOOP
(
i
,
nthreads
)
{
const
T
*
offset_roi
=
rois
+
i
*
BBoxSize
;
int
roi_batch_ind
=
roi_batch_id_data
[
i
];
// get the target level of current rois
T
roi_area
=
RoIArea
(
offset_roi
,
false
);
T
roi_scale
=
sqrt
(
roi_area
);
int
tgt_lvl
=
floor
(
log2
(
roi_scale
/
refer_scale
)
+
refer_level
);
tgt_lvl
=
min
(
max_level
,
max
(
tgt_lvl
,
min_level
));
target_lvls
[
i
]
=
tgt_lvl
;
// compute number of rois in the same batch and same target level
platform
::
CudaAtomicAdd
(
sub_lod_list
+
tgt_lvl
*
lod_size
+
roi_batch_ind
,
1
);
}
}
template
<
typename
DeviceContext
,
typename
T
>
class
GPUDistributeFpnProposalsOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
fpn_rois
=
ctx
.
Input
<
paddle
::
framework
::
LoDTensor
>
(
"FpnRois"
);
auto
multi_fpn_rois
=
ctx
.
MultiOutput
<
LoDTensor
>
(
"MultiFpnRois"
);
auto
*
restore_index
=
ctx
.
Output
<
Tensor
>
(
"RestoreIndex"
);
const
int
min_level
=
ctx
.
Attr
<
int
>
(
"min_level"
);
const
int
max_level
=
ctx
.
Attr
<
int
>
(
"max_level"
);
const
int
refer_level
=
ctx
.
Attr
<
int
>
(
"refer_level"
);
const
int
refer_scale
=
ctx
.
Attr
<
int
>
(
"refer_scale"
);
int
num_level
=
max_level
-
min_level
+
1
;
// check that the fpn_rois is not empty
PADDLE_ENFORCE_EQ
(
fpn_rois
->
lod
().
size
(),
1UL
,
"DistributeFpnProposalsOp need 1 level of LoD"
);
auto
fpn_rois_lod
=
fpn_rois
->
lod
().
back
();
int
lod_size
=
fpn_rois_lod
.
size
()
-
1
;
int
roi_num
=
fpn_rois_lod
[
lod_size
];
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
// get batch id by lod in CPU
Tensor
roi_batch_id_list
;
roi_batch_id_list
.
Resize
({
roi_num
});
int
*
roi_batch_id_data
=
roi_batch_id_list
.
mutable_data
<
int
>
(
platform
::
CPUPlace
());
for
(
int
n
=
0
;
n
<
lod_size
;
++
n
)
{
for
(
size_t
i
=
fpn_rois_lod
[
n
];
i
<
fpn_rois_lod
[
n
+
1
];
++
i
)
{
roi_batch_id_data
[
i
]
=
n
;
}
}
// copy batch id list to GPU
Tensor
roi_batch_id_list_gpu
;
framework
::
TensorCopySync
(
roi_batch_id_list
,
dev_ctx
.
GetPlace
(),
&
roi_batch_id_list_gpu
);
Tensor
sub_lod_list
;
sub_lod_list
.
Resize
({
num_level
,
lod_size
});
int
*
sub_lod_list_data
=
sub_lod_list
.
mutable_data
<
int
>
(
dev_ctx
.
GetPlace
());
Tensor
target_lvls
;
target_lvls
.
Resize
({
roi_num
});
int
*
target_lvls_data
=
target_lvls
.
mutable_data
<
int
>
(
dev_ctx
.
GetPlace
());
int
blocks
=
NumBlocks
(
roi_num
);
int
threads
=
kNumCUDAThreads
;
// get target levels and sub_lod list
GPUDistFpnProposalsHelper
<
T
><<<
blocks
,
threads
>>>
(
roi_num
,
fpn_rois
->
data
<
T
>
(),
lod_size
,
refer_level
,
refer_scale
,
max_level
,
min_level
,
roi_batch_id_list_gpu
.
data
<
int
>
(),
sub_lod_list_data
,
target_lvls_data
);
Tensor
index_in_t
;
int
*
idx_in
=
index_in_t
.
mutable_data
<
int
>
({
roi_num
},
dev_ctx
.
GetPlace
());
platform
::
ForRange
<
platform
::
CUDADeviceContext
>
for_range
(
dev_ctx
,
roi_num
);
for_range
(
RangeInitFunctor
{
0
,
1
,
idx_in
});
Tensor
keys_out_t
;
int
*
keys_out
=
keys_out_t
.
mutable_data
<
int
>
({
roi_num
},
dev_ctx
.
GetPlace
());
Tensor
index_out_t
;
int
*
idx_out
=
index_out_t
.
mutable_data
<
int
>
({
roi_num
},
dev_ctx
.
GetPlace
());
// Determine temporary device storage requirements
size_t
temp_storage_bytes
=
0
;
cub
::
DeviceRadixSort
::
SortPairsDescending
<
int
,
int
>
(
nullptr
,
temp_storage_bytes
,
target_lvls_data
,
keys_out
,
idx_in
,
idx_out
,
roi_num
);
// Allocate temporary storage
auto
place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
dev_ctx
.
GetPlace
());
auto
d_temp_storage
=
memory
::
Alloc
(
place
,
temp_storage_bytes
,
memory
::
Allocator
::
kScratchpad
);
// Run sorting operation
// sort target level to get corresponding index
cub
::
DeviceRadixSort
::
SortPairsDescending
<
int
,
int
>
(
d_temp_storage
->
ptr
(),
temp_storage_bytes
,
target_lvls_data
,
keys_out
,
idx_in
,
idx_out
,
roi_num
);
int
*
restore_idx_data
=
restore_index
->
mutable_data
<
int
>
({
roi_num
,
1
},
dev_ctx
.
GetPlace
());
// sort current index to get restore index
cub
::
DeviceRadixSort
::
SortPairsDescending
<
int
,
int
>
(
d_temp_storage
->
ptr
(),
temp_storage_bytes
,
idx_out
,
keys_out
,
idx_in
,
restore_idx_data
,
roi_num
);
Tensor
offset_lod
;
int
*
offset_lod_data
=
offset_lod
.
mutable_data
<
int
>
({
lod_size
+
1
},
dev_ctx
.
GetPlace
());
for
(
int
i
=
0
;
i
<
num_level
;
++
i
)
{
Tensor
sub_lod
=
sub_lod_list
.
Slice
(
i
,
i
+
1
);
int
*
sub_lod_data
=
sub_lod
.
data
<
int
>
();
// transfer length-based lod to offset-based lod
TransLoD
(
sub_lod_data
,
lod_size
+
1
,
offset_lod_data
);
int
sub_rois_num
=
offset_lod_data
[
lod_size
];
Tensor
sub_idx
=
index_out_t
.
Slice
(
0
,
sub_rois_num
);
multi_fpn_rois
[
i
]
->
mutable_data
<
T
>
({
sub_rois_num
,
kBoxDim
},
dev_ctx
.
GetPlace
());
GPUGather
<
T
>
(
dev_ctx
,
*
fpn_rois
,
sub_idx
,
multi_fpn_rois
[
i
]);
framework
::
LoD
lod
;
std
::
vector
<
size_t
>
offset
;
memory
::
Copy
(
platform
::
CPUPlace
(),
offset
.
data
(),
place
,
offset_lod_data
,
sizeof
(
int
)
*
(
lod_size
+
1
),
0
);
lod
.
emplace_back
(
offset
);
multi_fpn_rois
[
i
]
->
set_lod
(
lod
);
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
distribute_fpn_proposals
,
ops
::
GPUDistributeFpnProposalsOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
GPUDistributeFpnProposalsOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/detection/distribute_fpn_proposals_op.h
0 → 100644
浏览文件 @
c8610739
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <cmath>
#include <cstring>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
const
int
kBoxDim
=
4
;
template
<
typename
T
>
static
inline
T
BBoxArea
(
const
T
*
box
,
bool
normalized
)
{
if
(
box
[
2
]
<
box
[
0
]
||
box
[
3
]
<
box
[
1
])
{
// If coordinate values are is invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return
static_cast
<
T
>
(
0.
);
}
else
{
const
T
w
=
box
[
2
]
-
box
[
0
];
const
T
h
=
box
[
3
]
-
box
[
1
];
if
(
normalized
)
{
return
w
*
h
;
}
else
{
// If coordinate values are not within range [0, 1].
return
(
w
+
1
)
*
(
h
+
1
);
}
}
}
template
<
typename
T
>
class
DistributeFpnProposalsOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
fpn_rois
=
context
.
Input
<
paddle
::
framework
::
LoDTensor
>
(
"FpnRois"
);
auto
multi_fpn_rois
=
context
.
MultiOutput
<
paddle
::
framework
::
LoDTensor
>
(
"MultiFpnRois"
);
auto
*
restore_index
=
context
.
Output
<
paddle
::
framework
::
Tensor
>
(
"RestoreIndex"
);
const
int
min_level
=
context
.
Attr
<
int
>
(
"min_level"
);
const
int
max_level
=
context
.
Attr
<
int
>
(
"max_level"
);
const
int
refer_level
=
context
.
Attr
<
int
>
(
"refer_level"
);
const
int
refer_scale
=
context
.
Attr
<
int
>
(
"refer_scale"
);
const
int
num_level
=
max_level
-
min_level
+
1
;
// check that the fpn_rois is not empty
PADDLE_ENFORCE_EQ
(
fpn_rois
->
lod
().
size
(),
1UL
,
"DistributeFpnProposalsOp need 1 level of LoD"
);
auto
fpn_rois_lod
=
fpn_rois
->
lod
().
back
();
int
fpn_rois_num
=
fpn_rois_lod
[
fpn_rois_lod
.
size
()
-
1
];
std
::
vector
<
int
>
target_level
;
// std::vector<int> target_level(fpn_rois_num, -1);
// record the number of rois in each level
std
::
vector
<
int
>
num_rois_level
(
num_level
,
0
);
std
::
vector
<
int
>
num_rois_level_integral
(
num_level
+
1
,
0
);
for
(
int
i
=
0
;
i
<
fpn_rois_lod
.
size
()
-
1
;
++
i
)
{
Tensor
fpn_rois_slice
=
fpn_rois
->
Slice
(
fpn_rois_lod
[
i
],
fpn_rois_lod
[
i
+
1
]);
const
T
*
rois_data
=
fpn_rois_slice
.
data
<
T
>
();
for
(
int
j
=
0
;
j
<
fpn_rois_slice
.
dims
()[
0
];
++
j
)
{
// get the target level of current rois
T
roi_scale
=
std
::
sqrt
(
BBoxArea
(
rois_data
,
false
));
int
tgt_lvl
=
std
::
floor
(
std
::
log2
(
roi_scale
/
refer_scale
)
+
refer_level
);
tgt_lvl
=
std
::
min
(
max_level
,
std
::
max
(
tgt_lvl
,
min_level
));
target_level
.
push_back
(
tgt_lvl
);
num_rois_level
[
tgt_lvl
-
min_level
]
++
;
rois_data
+=
kBoxDim
;
}
}
// define the output rois
// pointer which point to each level fpn rois
std
::
vector
<
T
*>
multi_fpn_rois_data
(
num_level
);
// lod0 which will record the offset information of each level rois
std
::
vector
<
std
::
vector
<
size_t
>>
multi_fpn_rois_lod0
;
for
(
int
i
=
0
;
i
<
num_level
;
++
i
)
{
// allocate memory for each level rois
multi_fpn_rois
[
i
]
->
mutable_data
<
T
>
({
num_rois_level
[
i
],
kBoxDim
},
context
.
GetPlace
());
multi_fpn_rois_data
[
i
]
=
multi_fpn_rois
[
i
]
->
data
<
T
>
();
std
::
vector
<
size_t
>
lod0
(
1
,
0
);
multi_fpn_rois_lod0
.
push_back
(
lod0
);
// statistic start point for each level rois
num_rois_level_integral
[
i
+
1
]
=
num_rois_level_integral
[
i
]
+
num_rois_level
[
i
];
}
restore_index
->
mutable_data
<
int
>
({
1
,
fpn_rois_num
},
context
.
GetPlace
());
int
*
restore_index_data
=
restore_index
->
data
<
int
>
();
std
::
vector
<
int
>
restore_index_inter
(
fpn_rois_num
,
-
1
);
// distribute the rois into different fpn level by target level
for
(
int
i
=
0
;
i
<
fpn_rois_lod
.
size
()
-
1
;
++
i
)
{
Tensor
fpn_rois_slice
=
fpn_rois
->
Slice
(
fpn_rois_lod
[
i
],
fpn_rois_lod
[
i
+
1
]);
const
T
*
rois_data
=
fpn_rois_slice
.
data
<
T
>
();
size_t
cur_offset
=
fpn_rois_lod
[
i
];
// std::vector<size_t > lod_offset[num_level];
for
(
int
j
=
0
;
j
<
num_level
;
j
++
)
{
multi_fpn_rois_lod0
[
j
].
push_back
(
multi_fpn_rois_lod0
[
j
][
i
]);
}
for
(
int
j
=
0
;
j
<
fpn_rois_slice
.
dims
()[
0
];
++
j
)
{
int
lvl
=
target_level
[
cur_offset
+
j
];
memcpy
(
multi_fpn_rois_data
[
lvl
-
min_level
],
rois_data
,
kBoxDim
*
sizeof
(
T
));
multi_fpn_rois_data
[
lvl
-
min_level
]
+=
kBoxDim
;
int
index_in_shuffle
=
num_rois_level_integral
[
lvl
-
min_level
]
+
multi_fpn_rois_lod0
[
lvl
-
min_level
][
i
+
1
];
restore_index_inter
[
index_in_shuffle
]
=
cur_offset
+
j
;
multi_fpn_rois_lod0
[
lvl
-
min_level
][
i
+
1
]
++
;
rois_data
+=
kBoxDim
;
}
}
for
(
int
i
=
0
;
i
<
fpn_rois_num
;
++
i
)
{
restore_index_data
[
restore_index_inter
[
i
]]
=
i
;
}
// merge lod information into LoDTensor
for
(
int
i
=
0
;
i
<
num_level
;
++
i
)
{
framework
::
LoD
lod
;
lod
.
emplace_back
(
multi_fpn_rois_lod0
[
i
]);
multi_fpn_rois
[
i
]
->
set_lod
(
lod
);
}
}
};
}
// namespace operators
}
// namespace paddle
python/paddle/fluid/layers/detection.py
浏览文件 @
c8610739
...
...
@@ -51,6 +51,7 @@ __all__ = [
'yolov3_loss'
,
'box_clip'
,
'multiclass_nms'
,
'distribute_fpn_proposals'
,
'box_decoder_and_assign'
,
]
...
...
@@ -2224,6 +2225,79 @@ def multiclass_nms(bboxes,
return
output
def
distribute_fpn_proposals
(
fpn_rois
,
min_level
,
max_level
,
refer_level
,
refer_scale
,
name
=
None
):
"""
In Feature Pyramid Networks (FPN) models, it is needed to distribute all
proposals into different FPN level, with respect to scale of the proposals,
the referring scale and the referring level. Besides, to restore the order
of proposals, we return an array which indicates the original index of rois
in current proposals. To compute FPN level for each roi, the formula is
given as follows:
.. math::
roi\_scale &= \sqrt{BBoxArea(fpn\_roi)}
level = floor(&\log(
\\
frac{roi\_scale}{refer\_scale}) + refer\_level)
where BBoxArea is a function to compute the area of each roi.
Args:
fpn_rois(variable): The input fpn_rois, the second dimension is 4.
min_level(int): The lowest level of FPN layer where the proposals come
from.
max_level(int): The highest level of FPN layer where the proposals
come from.
refer_level(int): The referring level of FPN layer with specified scale.
refer_scale(int): The referring scale of FPN layer with specified level.
name(str|None): The name of this operator.
Returns:
tuple:
A tuple(multi_rois, restore_ind) is returned. The multi_rois is
a list of segmented tensor variables. The restore_ind is a 2D
Tensor with shape [N, 1], N is the number of total rois. It is
used to restore the order of fpn_rois.
Examples:
.. code-block:: python
fpn_rois = fluid.layers.data(
name='data', shape=[4], dtype='float32', lod_level=1)
multi_rois, restore_ind = fluid.layers.distribute_fpn_proposals(
fpn_rois=fpn_rois,
min_level=2,
max_level=5,
refer_level=4,
refer_scale=224)
"""
helper
=
LayerHelper
(
'distribute_fpn_proposals'
,
**
locals
())
dtype
=
helper
.
input_dtype
()
num_lvl
=
max_level
-
min_level
+
1
multi_rois
=
[
helper
.
create_variable_for_type_inference
(
dtype
)
for
i
in
range
(
num_lvl
)
]
restore_ind
=
helper
.
create_variable_for_type_inference
(
dtype
=
'int32'
)
helper
.
append_op
(
type
=
'distribute_fpn_proposals'
,
inputs
=
{
'FpnRois'
:
fpn_rois
},
outputs
=
{
'MultiFpnRois'
:
multi_rois
,
'RestoreIndex'
:
restore_ind
},
attrs
=
{
'min_level'
:
min_level
,
'max_level'
:
max_level
,
'refer_level'
:
refer_level
,
'refer_scale'
:
refer_scale
})
return
multi_rois
,
restore_ind
@
templatedoc
()
def
box_decoder_and_assign
(
prior_box
,
prior_box_var
,
...
...
python/paddle/fluid/tests/test_detection.py
浏览文件 @
c8610739
...
...
@@ -504,5 +504,21 @@ class TestMulticlassNMS(unittest.TestCase):
self
.
assertIsNotNone
(
output
)
class
TestDistributeFpnProposals
(
unittest
.
TestCase
):
def
test_distribute_fpn_proposals
(
self
):
program
=
Program
()
with
program_guard
(
program
):
fpn_rois
=
fluid
.
layers
.
data
(
name
=
'data'
,
shape
=
[
4
],
dtype
=
'float32'
,
lod_level
=
1
)
multi_rois
,
restore_ind
=
layers
.
distribute_fpn_proposals
(
fpn_rois
=
fpn_rois
,
min_level
=
2
,
max_level
=
5
,
refer_level
=
4
,
refer_scale
=
224
)
self
.
assertIsNotNone
(
multi_rois
)
self
.
assertIsNotNone
(
restore_ind
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_distribute_fpn_proposals_op.py
0 → 100644
浏览文件 @
c8610739
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
math
import
sys
from
op_test
import
OpTest
class
TestDistributeFPNProposalsOp
(
OpTest
):
def
set_data
(
self
):
self
.
init_test_case
()
self
.
make_rois
()
self
.
rois_fpn
,
self
.
rois_idx_restore
=
self
.
calc_rois_distribute
()
self
.
inputs
=
{
'FpnRois'
:
(
self
.
rois
[:,
1
:
5
],
self
.
rois_lod
)}
self
.
attrs
=
{
'max_level'
:
self
.
roi_max_level
,
'min_level'
:
self
.
roi_min_level
,
'refer_scale'
:
self
.
canonical_scale
,
'refer_level'
:
self
.
canonical_level
}
output
=
[(
'out%d'
%
i
,
self
.
rois_fpn
[
i
])
for
i
in
range
(
len
(
self
.
rois_fpn
))]
self
.
outputs
=
{
'MultiFpnRois'
:
output
,
'RestoreIndex'
:
self
.
rois_idx_restore
}
def
init_test_case
(
self
):
self
.
roi_max_level
=
5
self
.
roi_min_level
=
2
self
.
canonical_scale
=
224
self
.
canonical_level
=
4
self
.
images_shape
=
[
512
,
512
]
def
boxes_area
(
self
,
boxes
):
w
=
(
boxes
[:,
2
]
-
boxes
[:,
0
]
+
1
)
h
=
(
boxes
[:,
3
]
-
boxes
[:,
1
]
+
1
)
areas
=
w
*
h
assert
np
.
all
(
areas
>=
0
),
'Negative areas founds'
return
areas
def
map_rois_to_fpn_levels
(
self
,
rois
,
lvl_min
,
lvl_max
):
s
=
np
.
sqrt
(
self
.
boxes_area
(
rois
))
s0
=
self
.
canonical_scale
lvl0
=
self
.
canonical_level
target_lvls
=
np
.
floor
(
lvl0
+
np
.
log2
(
s
/
s0
+
1e-6
))
target_lvls
=
np
.
clip
(
target_lvls
,
lvl_min
,
lvl_max
)
return
target_lvls
def
get_sub_lod
(
self
,
sub_lvl
):
sub_lod
=
[]
max_batch_id
=
sub_lvl
[
-
1
]
for
i
in
range
(
max_batch_id
.
astype
(
np
.
int32
)
+
1
):
sub_lod
.
append
(
np
.
where
(
sub_lvl
==
i
)[
0
].
size
)
return
sub_lod
def
add_multilevel_roi
(
self
,
rois
,
target_lvls
,
lvl_min
,
lvl_max
):
rois_idx_order
=
np
.
empty
((
0
,
))
rois_fpn
=
[]
for
lvl
in
range
(
lvl_min
,
lvl_max
+
1
):
idx_lvl
=
np
.
where
(
target_lvls
==
lvl
)[
0
]
if
len
(
idx_lvl
)
==
0
:
rois_fpn
.
append
((
np
.
empty
(
shape
=
(
0
,
4
)),
[[
0
,
0
]]))
continue
sub_lod
=
self
.
get_sub_lod
(
rois
[
idx_lvl
,
0
])
rois_fpn
.
append
((
rois
[
idx_lvl
,
1
:],
[
sub_lod
]))
rois_idx_order
=
np
.
concatenate
((
rois_idx_order
,
idx_lvl
))
rois_idx_restore
=
np
.
argsort
(
rois_idx_order
).
astype
(
np
.
int32
,
copy
=
False
)
return
rois_fpn
,
rois_idx_restore
def
calc_rois_distribute
(
self
):
lvl_min
=
self
.
roi_min_level
lvl_max
=
self
.
roi_max_level
target_lvls
=
self
.
map_rois_to_fpn_levels
(
self
.
rois
[:,
1
:
5
],
lvl_min
,
lvl_max
)
rois_fpn
,
rois_idx_restore
=
self
.
add_multilevel_roi
(
self
.
rois
,
target_lvls
,
lvl_min
,
lvl_max
)
return
rois_fpn
,
rois_idx_restore
def
make_rois
(
self
):
self
.
rois_lod
=
[[
100
,
200
]]
rois
=
[]
lod
=
self
.
rois_lod
[
0
]
bno
=
0
for
roi_num
in
lod
:
for
i
in
range
(
roi_num
):
xywh
=
np
.
random
.
rand
(
4
)
xy1
=
xywh
[
0
:
2
]
*
20
wh
=
xywh
[
2
:
4
]
*
(
self
.
images_shape
-
xy1
)
xy2
=
xy1
+
wh
roi
=
[
bno
,
xy1
[
0
],
xy1
[
1
],
xy2
[
0
],
xy2
[
1
]]
rois
.
append
(
roi
)
bno
+=
1
self
.
rois
=
np
.
array
(
rois
).
astype
(
"float32"
)
def
setUp
(
self
):
self
.
op_type
=
"distribute_fpn_proposals"
self
.
set_data
()
def
test_check_output
(
self
):
self
.
check_output
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录