Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
cc95a751
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
cc95a751
编写于
5月 06, 2019
作者:
J
jerrywgz
提交者:
GitHub
5月 06, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix distribute fpn proposals, test=develop (#16152)
* fix distribute fpn proposals, test=develop
上级
9ec4615d
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
81 addition
and
70 deletion
+81
-70
paddle/fluid/operators/detection/bbox_util.h
paddle/fluid/operators/detection/bbox_util.h
+26
-0
paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc
.../fluid/operators/detection/distribute_fpn_proposals_op.cc
+2
-2
paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu
.../fluid/operators/detection/distribute_fpn_proposals_op.cu
+42
-61
paddle/fluid/operators/detection/distribute_fpn_proposals_op.h
...e/fluid/operators/detection/distribute_fpn_proposals_op.h
+3
-3
python/paddle/fluid/layers/detection.py
python/paddle/fluid/layers/detection.py
+1
-1
python/paddle/fluid/tests/unittests/test_distribute_fpn_proposals_op.py
...fluid/tests/unittests/test_distribute_fpn_proposals_op.py
+7
-3
未找到文件。
paddle/fluid/operators/detection/bbox_util.h
浏览文件 @
cc95a751
...
...
@@ -15,11 +15,37 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
namespace
paddle
{
namespace
operators
{
struct
RangeInitFunctor
{
int
start_
;
int
delta_
;
int
*
out_
;
HOSTDEVICE
void
operator
()(
size_t
i
)
{
out_
[
i
]
=
start_
+
i
*
delta_
;
}
};
template
<
typename
T
>
inline
HOSTDEVICE
T
RoIArea
(
const
T
*
box
,
bool
normalized
)
{
if
(
box
[
2
]
<
box
[
0
]
||
box
[
3
]
<
box
[
1
])
{
// If coordinate values are is invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return
static_cast
<
T
>
(
0.
);
}
else
{
const
T
w
=
box
[
2
]
-
box
[
0
];
const
T
h
=
box
[
3
]
-
box
[
1
];
if
(
normalized
)
{
return
w
*
h
;
}
else
{
// If coordinate values are not within range [0, 1].
return
(
w
+
1
)
*
(
h
+
1
);
}
}
}
/*
* transform that computes target bounding-box regression deltas
* given proposal boxes and ground-truth boxes.
...
...
paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc
浏览文件 @
cc95a751
...
...
@@ -40,14 +40,14 @@ class DistributeFpnProposalsOp : public framework::OperatorWithKernel {
outs_dims
.
push_back
(
out_dim
);
}
ctx
->
SetOutputsDim
(
"MultiFpnRois"
,
outs_dims
);
ctx
->
SetOutputDim
(
"RestoreIndex"
,
{
1
,
-
1
});
ctx
->
SetOutputDim
(
"RestoreIndex"
,
{
-
1
,
1
});
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
data_type
=
framework
::
GetDataTypeOfVar
(
ctx
.
InputVar
(
"FpnRois"
));
return
framework
::
OpKernelType
(
data_type
,
platform
::
CPUPlace
());
return
framework
::
OpKernelType
(
data_type
,
ctx
.
device_context
());
}
};
...
...
paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu
浏览文件 @
cc95a751
...
...
@@ -15,8 +15,10 @@ limitations under the License. */
#include <paddle/fluid/memory/allocation/allocator.h>
#include "cub/cub.cuh"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/operators/detection/distribute_fpn_proposals_op.h"
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/for_range.h"
...
...
@@ -26,7 +28,7 @@ namespace operators {
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
static
constexpr
int
kNumCUDAThreads
=
512
;
static
constexpr
int
kNumCUDAThreads
=
64
;
static
constexpr
int
kNumMaxinumNumBlocks
=
4096
;
#define CUDA_1D_KERNEL_LOOP(i, n) \
...
...
@@ -35,47 +37,13 @@ static constexpr int kNumMaxinumNumBlocks = 4096;
int
const
BBoxSize
=
4
;
struct
RangeInitFunctor
{
int
start_
;
int
delta_
;
int
*
out_
;
__device__
void
operator
()(
size_t
i
)
{
out_
[
i
]
=
start_
+
i
*
delta_
;
}
};
static
inline
int
NumBlocks
(
const
int
N
)
{
return
std
::
min
((
N
+
kNumCUDAThreads
-
1
)
/
kNumCUDAThreads
,
kNumMaxinumNumBlocks
);
}
static
inline
void
TransLoD
(
const
int
*
length_lod
,
const
int
lod_size
,
int
*
offset_lod
)
{
int
offset
=
0
;
for
(
int
i
=
0
;
i
<
lod_size
;
++
i
)
{
offset_lod
[
i
]
=
offset
;
offset
+=
length_lod
[
i
];
}
}
template
<
typename
T
>
static
__device__
inline
T
RoIArea
(
const
T
*
box
,
bool
normalized
)
{
if
(
box
[
2
]
<
box
[
0
]
||
box
[
3
]
<
box
[
1
])
{
// If coordinate values are is invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return
static_cast
<
T
>
(
0.
);
}
else
{
const
T
w
=
box
[
2
]
-
box
[
0
];
const
T
h
=
box
[
3
]
-
box
[
1
];
if
(
normalized
)
{
return
w
*
h
;
}
else
{
// If coordinate values are not within range [0, 1].
return
(
w
+
1
)
*
(
h
+
1
);
}
}
}
template
<
class
T
>
static
__global__
void
GPUDistFpnProposalsHelper
(
__global__
void
GPUDistFpnProposalsHelper
(
const
int
nthreads
,
const
T
*
rois
,
const
int
lod_size
,
const
int
refer_level
,
const
int
refer_scale
,
const
int
max_level
,
const
int
min_level
,
int
*
roi_batch_id_data
,
int
*
sub_lod_list
,
...
...
@@ -86,12 +54,13 @@ static __global__ void GPUDistFpnProposalsHelper(
// get the target level of current rois
T
roi_area
=
RoIArea
(
offset_roi
,
false
);
T
roi_scale
=
sqrt
(
roi_area
);
int
tgt_lvl
=
floor
(
log2
(
roi_scale
/
refer_scale
)
+
refer_level
);
int
tgt_lvl
=
floor
(
log2
(
roi_scale
/
static_cast
<
T
>
(
refer_scale
)
+
(
T
)
1e-6
)
+
refer_level
);
tgt_lvl
=
min
(
max_level
,
max
(
tgt_lvl
,
min_level
));
target_lvls
[
i
]
=
tgt_lvl
;
// compute number of rois in the same batch and same target level
platform
::
CudaAtomicAdd
(
sub_lod_list
+
tgt_lvl
*
lod_size
+
roi_batch_ind
,
1
);
platform
::
CudaAtomicAdd
(
sub_lod_list
+
(
tgt_lvl
-
min_level
)
*
lod_size
+
roi_batch_ind
,
1
);
}
}
...
...
@@ -138,18 +107,22 @@ class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
Tensor
sub_lod_list
;
sub_lod_list
.
Resize
({
num_level
,
lod_size
});
int
*
sub_lod_list_data
=
sub_lod_list
.
mutable_data
<
int
>
(
dev_ctx
.
GetPlace
());
math
::
SetConstant
<
platform
::
CUDADeviceContext
,
int
>
set_zero
;
set_zero
(
dev_ctx
,
&
sub_lod_list
,
static_cast
<
int
>
(
0
));
Tensor
target_lvls
;
target_lvls
.
Resize
({
roi_num
});
int
*
target_lvls_data
=
target_lvls
.
mutable_data
<
int
>
(
dev_ctx
.
GetPlace
());
int
blocks
=
NumBlocks
(
roi_num
);
int
dist_
blocks
=
NumBlocks
(
roi_num
);
int
threads
=
kNumCUDAThreads
;
// get target levels and sub_lod list
GPUDistFpnProposalsHelper
<
T
><<<
blocks
,
threads
>>>
(
GPUDistFpnProposalsHelper
<
T
><<<
dist_
blocks
,
threads
>>>
(
roi_num
,
fpn_rois
->
data
<
T
>
(),
lod_size
,
refer_level
,
refer_scale
,
max_level
,
min_level
,
roi_batch_id_list_gpu
.
data
<
int
>
(),
sub_lod_list_data
,
target_lvls_data
);
dev_ctx
.
Wait
();
auto
place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
dev_ctx
.
GetPlace
());
Tensor
index_in_t
;
int
*
idx_in
=
index_in_t
.
mutable_data
<
int
>
({
roi_num
},
dev_ctx
.
GetPlace
());
...
...
@@ -163,46 +136,54 @@ class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
// Determine temporary device storage requirements
size_t
temp_storage_bytes
=
0
;
cub
::
DeviceRadixSort
::
SortPairs
Descending
<
int
,
int
>
(
nullptr
,
temp_storage_bytes
,
target_lvls_data
,
keys_out
,
idx_in
,
idx_out
,
roi_num
);
cub
::
DeviceRadixSort
::
SortPairs
<
int
,
int
>
(
nullptr
,
temp_storage_bytes
,
target_lvls_data
,
keys_out
,
idx_in
,
idx_out
,
roi_num
);
// Allocate temporary storage
auto
place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
dev_ctx
.
GetPlace
());
auto
d_temp_storage
=
memory
::
Alloc
(
place
,
temp_storage_bytes
,
memory
::
Allocator
::
kScratchpad
);
// Run sorting operation
// sort target level to get corresponding index
cub
::
DeviceRadixSort
::
SortPairs
Descending
<
int
,
int
>
(
cub
::
DeviceRadixSort
::
SortPairs
<
int
,
int
>
(
d_temp_storage
->
ptr
(),
temp_storage_bytes
,
target_lvls_data
,
keys_out
,
idx_in
,
idx_out
,
roi_num
);
int
*
restore_idx_data
=
restore_index
->
mutable_data
<
int
>
({
roi_num
,
1
},
dev_ctx
.
GetPlace
());
// sort current index to get restore index
cub
::
DeviceRadixSort
::
SortPairs
Descending
<
int
,
int
>
(
cub
::
DeviceRadixSort
::
SortPairs
<
int
,
int
>
(
d_temp_storage
->
ptr
(),
temp_storage_bytes
,
idx_out
,
keys_out
,
idx_in
,
restore_idx_data
,
roi_num
);
Tensor
offset_lod
;
int
*
offset_lod_data
=
offset_lod
.
mutable_data
<
int
>
({
lod_size
+
1
},
dev_ctx
.
GetPlace
());
int
start
=
0
;
for
(
int
i
=
0
;
i
<
num_level
;
++
i
)
{
Tensor
sub_lod
=
sub_lod_list
.
Slice
(
i
,
i
+
1
);
int
*
sub_lod_data
=
sub_lod
.
data
<
int
>
();
// transfer length-based lod to offset-based lod
TransLoD
(
sub_lod_data
,
lod_size
+
1
,
offset_lod_data
);
int
sub_rois_num
=
offset_lod_data
[
lod_size
];
Tensor
sub_idx
=
index_out_t
.
Slice
(
0
,
sub_rois_num
);
std
::
vector
<
size_t
>
offset
(
1
,
0
);
std
::
vector
<
int
>
sub_lod_cpu
(
lod_size
);
memory
::
Copy
(
platform
::
CPUPlace
(),
sub_lod_cpu
.
data
(),
place
,
sub_lod_data
,
sizeof
(
int
)
*
lod_size
,
dev_ctx
.
stream
());
dev_ctx
.
Wait
();
for
(
int
j
=
0
;
j
<
lod_size
;
++
j
)
{
offset
.
emplace_back
(
offset
.
back
()
+
sub_lod_cpu
[
j
]);
}
int
sub_rois_num
=
offset
.
back
();
int
end
=
start
+
sub_rois_num
;
if
(
end
>
start
)
{
Tensor
sub_idx
=
index_out_t
.
Slice
(
start
,
end
);
start
=
end
;
multi_fpn_rois
[
i
]
->
mutable_data
<
T
>
({
sub_rois_num
,
kBoxDim
},
dev_ctx
.
GetPlace
());
GPUGather
<
T
>
(
dev_ctx
,
*
fpn_rois
,
sub_idx
,
multi_fpn_rois
[
i
]);
}
else
{
multi_fpn_rois
[
i
]
->
mutable_data
<
T
>
({
sub_rois_num
,
kBoxDim
},
dev_ctx
.
GetPlace
());
}
framework
::
LoD
lod
;
std
::
vector
<
size_t
>
offset
;
memory
::
Copy
(
platform
::
CPUPlace
(),
offset
.
data
(),
place
,
offset_lod_data
,
sizeof
(
int
)
*
(
lod_size
+
1
),
0
);
lod
.
emplace_back
(
offset
);
multi_fpn_rois
[
i
]
->
set_lod
(
lod
);
}
...
...
paddle/fluid/operators/detection/distribute_fpn_proposals_op.h
浏览文件 @
cc95a751
...
...
@@ -83,8 +83,8 @@ class DistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
for
(
int
j
=
0
;
j
<
fpn_rois_slice
.
dims
()[
0
];
++
j
)
{
// get the target level of current rois
T
roi_scale
=
std
::
sqrt
(
BBoxArea
(
rois_data
,
false
));
int
tgt_lvl
=
std
::
floor
(
std
::
log2
(
roi_scale
/
refer_scale
)
+
refer_level
);
int
tgt_lvl
=
std
::
floor
(
std
::
log2
(
roi_scale
/
refer_scale
+
(
T
)
1e-6
)
+
refer_level
);
tgt_lvl
=
std
::
min
(
max_level
,
std
::
max
(
tgt_lvl
,
min_level
));
target_level
.
push_back
(
tgt_lvl
);
num_rois_level
[
tgt_lvl
-
min_level
]
++
;
...
...
@@ -107,7 +107,7 @@ class DistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
num_rois_level_integral
[
i
+
1
]
=
num_rois_level_integral
[
i
]
+
num_rois_level
[
i
];
}
restore_index
->
mutable_data
<
int
>
({
1
,
fpn_rois_num
},
context
.
GetPlace
());
restore_index
->
mutable_data
<
int
>
({
fpn_rois_num
,
1
},
context
.
GetPlace
());
int
*
restore_index_data
=
restore_index
->
data
<
int
>
();
std
::
vector
<
int
>
restore_index_inter
(
fpn_rois_num
,
-
1
);
// distribute the rois into different fpn level by target level
...
...
python/paddle/fluid/layers/detection.py
浏览文件 @
cc95a751
...
...
@@ -2383,7 +2383,7 @@ def distribute_fpn_proposals(fpn_rois,
"""
helper
=
LayerHelper
(
'distribute_fpn_proposals'
,
**
locals
())
dtype
=
helper
.
input_dtype
()
dtype
=
helper
.
input_dtype
(
'fpn_rois'
)
num_lvl
=
max_level
-
min_level
+
1
multi_rois
=
[
helper
.
create_variable_for_type_inference
(
dtype
)
for
i
in
range
(
num_lvl
)
...
...
python/paddle/fluid/tests/unittests/test_distribute_fpn_proposals_op.py
浏览文件 @
cc95a751
...
...
@@ -37,7 +37,7 @@ class TestDistributeFPNProposalsOp(OpTest):
for
i
in
range
(
len
(
self
.
rois_fpn
))]
self
.
outputs
=
{
'MultiFpnRois'
:
output
,
'RestoreIndex'
:
self
.
rois_idx_restore
'RestoreIndex'
:
self
.
rois_idx_restore
.
reshape
(
-
1
,
1
)
}
def
init_test_case
(
self
):
...
...
@@ -63,10 +63,10 @@ class TestDistributeFPNProposalsOp(OpTest):
return
target_lvls
def
get_sub_lod
(
self
,
sub_lvl
):
sub_lod
=
[]
sub_lod
=
[
0
,
0
]
max_batch_id
=
sub_lvl
[
-
1
]
for
i
in
range
(
max_batch_id
.
astype
(
np
.
int32
)
+
1
):
sub_lod
.
append
(
np
.
where
(
sub_lvl
==
i
)[
0
].
size
)
sub_lod
[
i
]
=
np
.
where
(
sub_lvl
==
i
)[
0
].
size
return
sub_lod
def
add_multilevel_roi
(
self
,
rois
,
target_lvls
,
lvl_min
,
lvl_max
):
...
...
@@ -115,3 +115,7 @@ class TestDistributeFPNProposalsOp(OpTest):
def
test_check_output
(
self
):
self
.
check_output
()
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录