Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
d3ac561d
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d3ac561d
编写于
7月 25, 2019
作者:
B
Bai Yifan
提交者:
GitHub
7月 25, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix deformable_conv_op compile error, test=develop (#18793)
上级
9ecd8ee7
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
30 addition
and
30 deletion
+30
-30
paddle/fluid/operators/deformable_conv_op.cu
paddle/fluid/operators/deformable_conv_op.cu
+30
-30
未找到文件。
paddle/fluid/operators/deformable_conv_op.cu
浏览文件 @
d3ac561d
...
...
@@ -200,6 +200,36 @@ __device__ T DmcnGetCoordinateWeight(T argmax_h, T argmax_w, const int height,
return
weight
;
}
template
<
typename
T
>
__device__
T
DmcnIm2colBilinear
(
const
T
*
bottom_data
,
const
int
data_width
,
const
int
height
,
const
int
width
,
T
h
,
T
w
)
{
int
h_low
=
floor
(
h
);
int
w_low
=
floor
(
w
);
int
h_high
=
h_low
+
1
;
int
w_high
=
w_low
+
1
;
T
lh
=
h
-
h_low
;
T
lw
=
w
-
w_low
;
T
hh
=
1
-
lh
,
hw
=
1
-
lw
;
T
v1
=
0
;
if
(
h_low
>=
0
&&
w_low
>=
0
)
v1
=
bottom_data
[
h_low
*
data_width
+
w_low
];
T
v2
=
0
;
if
(
h_low
>=
0
&&
w_high
<=
width
-
1
)
v2
=
bottom_data
[
h_low
*
data_width
+
w_high
];
T
v3
=
0
;
if
(
h_high
<=
height
-
1
&&
w_low
>=
0
)
v3
=
bottom_data
[
h_high
*
data_width
+
w_low
];
T
v4
=
0
;
if
(
h_high
<=
height
-
1
&&
w_high
<=
width
-
1
)
v4
=
bottom_data
[
h_high
*
data_width
+
w_high
];
T
w1
=
hh
*
hw
,
w2
=
hh
*
lw
,
w3
=
lh
*
hw
,
w4
=
lh
*
lw
;
T
val
=
(
w1
*
v1
+
w2
*
v2
+
w3
*
v3
+
w4
*
v4
);
return
val
;
}
template
<
typename
T
>
__global__
void
ModulatedDeformableCol2imCoordGpuKernel
(
const
int
nthreads
,
const
T
*
data_col
,
const
T
*
data_im
,
...
...
@@ -315,36 +345,6 @@ inline void ModulatedDeformableCol2imCoord(
deformable_groups
,
col_shape
[
2
],
col_shape
[
3
],
grad_offset
,
grad_mask
);
}
template
<
typename
T
>
__device__
T
DmcnIm2colBilinear
(
const
T
*
bottom_data
,
const
int
data_width
,
const
int
height
,
const
int
width
,
T
h
,
T
w
)
{
int
h_low
=
floor
(
h
);
int
w_low
=
floor
(
w
);
int
h_high
=
h_low
+
1
;
int
w_high
=
w_low
+
1
;
T
lh
=
h
-
h_low
;
T
lw
=
w
-
w_low
;
T
hh
=
1
-
lh
,
hw
=
1
-
lw
;
T
v1
=
0
;
if
(
h_low
>=
0
&&
w_low
>=
0
)
v1
=
bottom_data
[
h_low
*
data_width
+
w_low
];
T
v2
=
0
;
if
(
h_low
>=
0
&&
w_high
<=
width
-
1
)
v2
=
bottom_data
[
h_low
*
data_width
+
w_high
];
T
v3
=
0
;
if
(
h_high
<=
height
-
1
&&
w_low
>=
0
)
v3
=
bottom_data
[
h_high
*
data_width
+
w_low
];
T
v4
=
0
;
if
(
h_high
<=
height
-
1
&&
w_high
<=
width
-
1
)
v4
=
bottom_data
[
h_high
*
data_width
+
w_high
];
T
w1
=
hh
*
hw
,
w2
=
hh
*
lw
,
w3
=
lh
*
hw
,
w4
=
lh
*
lw
;
T
val
=
(
w1
*
v1
+
w2
*
v2
+
w3
*
v3
+
w4
*
v4
);
return
val
;
}
template
<
typename
T
>
__global__
void
ModulatedDeformableIm2colGpuKernel
(
const
int
nthreads
,
const
T
*
data_im
,
const
T
*
data_offset
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录