Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
c8e49be2
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c8e49be2
编写于
10月 28, 2019
作者:
W
whs
提交者:
GitHub
10月 28, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix roi_perspective_transform op (#20764)
上级
6bdf99d3
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
34 addition
and
33 deletion
+34
-33
paddle/fluid/operators/detection/roi_perspective_transform_op.cc
...fluid/operators/detection/roi_perspective_transform_op.cc
+16
-16
paddle/fluid/operators/detection/roi_perspective_transform_op.cu
...fluid/operators/detection/roi_perspective_transform_op.cu
+12
-11
python/paddle/fluid/tests/unittests/test_roi_perspective_transform_op.py
...luid/tests/unittests/test_roi_perspective_transform_op.py
+6
-6
未找到文件。
paddle/fluid/operators/detection/roi_perspective_transform_op.cc
浏览文件 @
c8e49be2
...
...
@@ -187,17 +187,17 @@ void bilinear_interpolate(const T* in_data, const int channels, const int width,
const
int
height
,
int
in_n
,
int
in_c
,
T
in_w
,
T
in_h
,
T
*
val
)
{
// Deal with cases that source coords are out of feature map boundary
if
(
GT
<
T
>
(
-
0.5
,
in_w
)
||
GT
<
T
>
(
in_w
,
width
-
0.5
)
||
GT
<
T
>
(
-
0.5
,
in_h
)
||
GT
<
T
>
(
in_h
,
height
-
0.5
))
{
if
(
GT
_E
<
T
>
(
-
0.5
,
in_w
)
||
GT_E
<
T
>
(
in_w
,
width
-
0.5
)
||
GT
_E
<
T
>
(
-
0.5
,
in_h
)
||
GT_E
<
T
>
(
in_h
,
height
-
0.5
))
{
// empty
val
[
0
]
=
0.0
;
return
;
}
if
(
GT
<
T
>
(
0
,
in_w
))
{
if
(
GT
_E
<
T
>
(
0
,
in_w
))
{
in_w
=
0
;
}
if
(
GT
<
T
>
(
0
,
in_h
))
{
if
(
GT
_E
<
T
>
(
0
,
in_h
))
{
in_h
=
0
;
}
...
...
@@ -301,10 +301,10 @@ class CPUROIPerspectiveTransformOpKernel : public framework::OpKernel<T> {
T
in_w
,
in_h
;
get_source_coords
<
T
>
(
matrix
,
out_w
,
out_h
,
&
in_w
,
&
in_h
);
if
(
in_quad
<
T
>
(
in_w
,
in_h
,
roi_x
,
roi_y
))
{
if
(
GT
<
T
>
(
-
0.5
,
in_w
)
||
GT
<
T
>
(
in_w
,
static_cast
<
T
>
(
in_width
-
0.5
))
||
GT
<
T
>
(
-
0.5
,
in_h
)
||
GT
<
T
>
(
in_h
,
static_cast
<
T
>
(
in_height
-
0.5
)))
{
if
(
GT
_E
<
T
>
(
-
0.5
,
in_w
)
||
GT
_E
<
T
>
(
in_w
,
static_cast
<
T
>
(
in_width
-
0.5
))
||
GT
_E
<
T
>
(
-
0.5
,
in_h
)
||
GT
_E
<
T
>
(
in_h
,
static_cast
<
T
>
(
in_height
-
0.5
)))
{
output_data
[
out_index
]
=
0.0
;
mask_data
[(
n
*
transformed_height
+
out_h
)
*
transformed_width
+
out_w
]
=
0
;
...
...
@@ -330,15 +330,15 @@ class CPUROIPerspectiveTransformOpKernel : public framework::OpKernel<T> {
template
<
typename
T
>
T
get_feature_gradient
(
T
xs
,
T
ys
,
int
w
,
int
h
,
const
int
width
,
const
int
height
)
{
if
(
GT
<
T
>
(
-
0.5
,
xs
)
||
GT
<
T
>
(
xs
,
width
-
0.5
)
||
GT
<
T
>
(
-
0.5
,
ys
)
||
GT
<
T
>
(
ys
,
height
-
0.5
))
{
if
(
GT
_E
<
T
>
(
-
0.5
,
xs
)
||
GT_E
<
T
>
(
xs
,
width
-
0.5
)
||
GT_E
<
T
>
(
-
0.5
,
ys
)
||
GT
_E
<
T
>
(
ys
,
height
-
0.5
))
{
return
0
;
}
if
(
GT
<
T
>
(
0
,
xs
))
{
if
(
GT
_E
<
T
>
(
0
,
xs
))
{
xs
=
0
;
}
if
(
GT
<
T
>
(
0
,
ys
))
{
if
(
GT
_E
<
T
>
(
0
,
ys
))
{
ys
=
0
;
}
...
...
@@ -441,10 +441,10 @@ class CPUROIPerspectiveTransformGradOpKernel : public framework::OpKernel<T> {
T
src_h
;
get_source_coords
<
T
>
(
matrix
,
out_w
,
out_h
,
&
src_w
,
&
src_h
);
if
(
in_quad
<
T
>
(
src_w
,
src_h
,
roi_x
,
roi_y
))
{
if
(
GT
<
T
>
(
-
0.5
,
src_w
)
||
GT
<
T
>
(
src_w
,
static_cast
<
T
>
(
in_width
-
0.5
))
||
GT
<
T
>
(
-
0.5
,
src_h
)
||
GT
<
T
>
(
src_h
,
static_cast
<
T
>
(
in_height
-
0.5
)))
{
if
(
GT
_E
<
T
>
(
-
0.5
,
src_w
)
||
GT
_E
<
T
>
(
src_w
,
static_cast
<
T
>
(
in_width
-
0.5
))
||
GT
_E
<
T
>
(
-
0.5
,
src_h
)
||
GT
_E
<
T
>
(
src_h
,
static_cast
<
T
>
(
in_height
-
0.5
)))
{
continue
;
}
T
weight
=
get_feature_gradient
<
T
>
(
src_w
,
src_h
,
in_w
,
in_h
,
...
...
paddle/fluid/operators/detection/roi_perspective_transform_op.cu
浏览文件 @
c8e49be2
...
...
@@ -120,16 +120,16 @@ __device__ void bilinear_interpolate(const T* in_data, const int channels,
int
out_idx
,
int
*
out2in_idx
,
T
*
out2in_w
)
{
// Deal with cases that source coords are out of feature map boundary
if
(
GT
<
T
>
(
-
0.5
,
in_w
)
||
GT
<
T
>
(
in_w
,
width
-
0.5
)
||
GT
<
T
>
(
-
0.5
,
in_h
)
||
GT
<
T
>
(
in_h
,
height
-
0.5
))
{
if
(
GT
_E
<
T
>
(
-
0.5
,
in_w
)
||
GT_E
<
T
>
(
in_w
,
width
-
0.5
)
||
GT
_E
<
T
>
(
-
0.5
,
in_h
)
||
GT_E
<
T
>
(
in_h
,
height
-
0.5
))
{
val
[
0
]
=
0.0
;
return
;
}
if
(
GT
<
T
>
(
0
,
in_w
))
{
if
(
GT
_E
<
T
>
(
0
,
in_w
))
{
in_w
=
0
;
}
if
(
GT
<
T
>
(
0
,
in_h
))
{
if
(
GT
_E
<
T
>
(
0
,
in_h
))
{
in_h
=
0
;
}
...
...
@@ -284,7 +284,6 @@ __global__ void RoiTransformKernel(const float* input_data,
int
*
mask
,
T
*
transform_matrix
)
{
int
output_size
=
num_rois
*
transformed_height
*
transformed_width
*
channels
;
CUDA_1D_KERNEL_LOOP
(
index
,
output_size
)
{
// (n, c, out_h, out_w) is an element in the transformed output
int
out_w
=
idx4_4
(
index
,
num_rois
,
channels
,
transformed_height
,
...
...
@@ -318,8 +317,10 @@ __global__ void RoiTransformKernel(const float* input_data,
get_source_coords
<
T
>
(
matrix
,
out_w
,
out_h
,
&
in_w
,
&
in_h
);
if
(
in_quad
<
T
>
(
in_w
,
in_h
,
roi_x
,
roi_y
))
{
if
(
GT
<
T
>
(
-
0.5
,
in_w
)
||
GT
<
T
>
(
in_w
,
static_cast
<
T
>
(
in_width
-
0.5
))
||
GT
<
T
>
(
-
0.5
,
in_h
)
||
GT
<
T
>
(
in_h
,
static_cast
<
T
>
(
in_height
-
0.5
)))
{
if
(
GT_E
<
T
>
(
-
0.5
,
in_w
)
||
GT_E
<
T
>
(
in_w
,
static_cast
<
T
>
(
in_width
-
0.5
))
||
GT_E
<
T
>
(
-
0.5
,
in_h
)
||
GT_E
<
T
>
(
in_h
,
static_cast
<
T
>
(
in_height
-
0.5
)))
{
// Skip if source coords is not in input image
output_data
[
index
]
=
0.0
;
mask
[(
n
*
transformed_height
+
out_h
)
*
transformed_width
+
out_w
]
=
0
;
...
...
@@ -409,15 +410,15 @@ class CUDAROIPerspectiveTransformOpKernel : public framework::OpKernel<T> {
template
<
typename
T
>
__device__
T
get_feature_gradient
(
T
xs
,
T
ys
,
int
w
,
int
h
,
const
int
width
,
const
int
height
)
{
if
(
GT
<
T
>
(
-
0.5
,
xs
)
||
GT
<
T
>
(
xs
,
width
-
0.5
)
||
GT
<
T
>
(
-
0.5
,
ys
)
||
GT
<
T
>
(
ys
,
height
-
0.5
))
{
if
(
GT
_E
<
T
>
(
-
0.5
,
xs
)
||
GT_E
<
T
>
(
xs
,
width
-
0.5
)
||
GT_E
<
T
>
(
-
0.5
,
ys
)
||
GT
_E
<
T
>
(
ys
,
height
-
0.5
))
{
return
0
;
}
if
(
GT
<
T
>
(
0
,
xs
))
{
if
(
GT
_E
<
T
>
(
0
,
xs
))
{
xs
=
0
;
}
if
(
GT
<
T
>
(
0
,
ys
))
{
if
(
GT
_E
<
T
>
(
0
,
ys
))
{
ys
=
0
;
}
...
...
python/paddle/fluid/tests/unittests/test_roi_perspective_transform_op.py
浏览文件 @
c8e49be2
...
...
@@ -135,13 +135,13 @@ def bilinear_interpolate(in_data, in_n, in_c, in_w, in_h):
height
=
in_data
.
shape
[
2
]
width
=
in_data
.
shape
[
3
]
if
gt
(
-
0.5
,
in_w
)
or
gt
(
in_w
,
width
-
0.5
)
or
gt
(
-
0.5
,
in_h
)
or
gt
(
if
gt
_e
(
-
0.5
,
in_w
)
or
gt_e
(
in_w
,
width
-
0.5
)
or
gt_e
(
-
0.5
,
in_h
)
or
gt_e
(
in_h
,
height
-
0.5
):
return
0.0
if
gt
(
0
,
in_w
):
if
gt
_e
(
0
,
in_w
):
in_w
=
0
if
gt
(
0
,
in_h
):
if
gt
_e
(
0
,
in_h
):
in_h
=
0
in_w_floor
=
floor
(
in_w
)
...
...
@@ -216,9 +216,9 @@ def roi_transform(in_data, rois, rois_lod, transformed_height,
for
out_w
in
range
(
transformed_width
):
in_w
,
in_h
=
get_source_coords
(
transform_matrix
,
out_w
,
out_h
)
if
in_quad
(
in_w
,
in_h
,
roi_x
,
roi_y
)
and
gt
_e
(
in_w
,
-
0.5
)
and
lt_e
(
in_w
,
in_width
-
0.5
)
and
gt_e
(
in_h
,
-
0.5
)
and
lt_e
(
in_h
,
in_height
-
0.5
):
if
in_quad
(
in_w
,
in_h
,
roi_x
,
roi_y
)
and
gt
(
in_w
,
-
0.5
)
and
gt
(
in_width
-
0.5
,
in_w
)
and
gt
(
in_h
,
-
0.5
)
and
gt
(
in_height
-
0.5
,
in_h
):
out
[
n
][
c
][
out_h
][
out_w
]
=
bilinear_interpolate
(
in_data
,
image_id
,
c
,
in_w
,
in_h
)
mask
[
n
][
0
][
out_h
][
out_w
]
=
1
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录