Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
1d7f91e8
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1d7f91e8
编写于
5月 18, 2018
作者:
B
baiyfbupt
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix roi_pool gpu_backward kernel
上级
f79779f0
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
4 addition
and
4 deletion
+4
-4
paddle/fluid/operators/roi_pool_op.cu
paddle/fluid/operators/roi_pool_op.cu
+4
-4
未找到文件。
paddle/fluid/operators/roi_pool_op.cu
浏览文件 @
1d7f91e8
...
...
@@ -101,10 +101,10 @@ __global__ void GPUROIPoolBackward(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
offset
=
blockDim
.
x
*
gridDim
.
x
;
for
(
int
i
=
index
;
i
<
nthreads
;
i
+=
offset
)
{
int
pw
=
i
ndex
%
pooled_width
;
int
ph
=
(
i
ndex
/
pooled_width
)
%
pooled_height
;
int
c
=
(
i
ndex
/
pooled_width
/
pooled_height
)
%
channels
;
int
n
=
i
ndex
/
pooled_width
/
pooled_height
/
channels
;
int
pw
=
i
%
pooled_width
;
int
ph
=
(
i
/
pooled_width
)
%
pooled_height
;
int
c
=
(
i
/
pooled_width
/
pooled_height
)
%
channels
;
int
n
=
i
/
pooled_width
/
pooled_height
/
channels
;
int
roi_batch_ind
=
roi_batch_id_data
[
n
];
int
input_offset
=
(
roi_batch_ind
*
channels
+
c
)
*
height
*
width
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录