Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
c8bb6631
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c8bb6631
编写于
11月 27, 2017
作者:
G
guosheng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refine roi_pool_op to avoid warning
上级
e6546baa
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
21 addition
and
28 deletion
+21
-28
paddle/operators/roi_pool_op.h
paddle/operators/roi_pool_op.h
+21
-28
未找到文件。
paddle/operators/roi_pool_op.h
100755 → 100644
浏览文件 @
c8bb6631
...
...
@@ -133,54 +133,47 @@ class CPUROIPoolGradOpKernel : public framework::OpKernel<T> {
auto
*
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
rois
=
ctx
.
Input
<
framework
::
Tensor
>
(
"ROIs"
);
auto
*
argmax
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Argmax"
);
auto
*
out_grad
=
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
x_grad
=
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
in_grad
=
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
pooled_height
=
ctx
.
Attr
<
int
>
(
"pooled_height"
);
auto
pooled_width
=
ctx
.
Attr
<
int
>
(
"pooled_width"
);
if
(
x_grad
)
{
int
channels
=
in
->
dims
()[
1
];
auto
in_stride
=
framework
::
stride
(
in
->
dims
());
auto
roi_stride
=
framework
::
stride
(
rois
->
dims
());
if
(
in_grad
)
{
const
int64_t
*
rois_data
=
rois
->
data
<
int64_t
>
();
int
rois_num
=
rois
->
dims
()[
0
]
;
T
*
x_grad_data
=
x
_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
const
T
*
out_grad_data
=
out_grad
->
data
<
T
>
()
;
const
int64_t
*
argmax_data
=
argmax
->
data
<
int64_t
>
();
T
*
in_grad_data
=
in
_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
math
::
SetConstant
<
Place
,
T
>
set_zero
;
set_zero
(
ctx
.
device_context
(),
x
_grad
,
static_cast
<
T
>
(
0
));
set_zero
(
ctx
.
device_context
(),
in
_grad
,
static_cast
<
T
>
(
0
));
size_t
roi_offset
=
roi_stride
[
0
];
size_t
batch_offset
=
in_stride
[
0
];
size_t
channel_offset
=
in_stride
[
1
];
auto
in_stride
=
framework
::
stride
(
in
->
dims
());
auto
argmax_stride
=
framework
::
stride
(
argmax
->
dims
());
auto
roi_stride
=
framework
::
stride
(
rois
->
dims
());
auto
out_stride
=
framework
::
stride
(
out_grad
->
dims
());
const
T
*
out_grad_data
=
out_grad
->
data
<
T
>
();
size_t
pool_channel_offset
=
pooled_height
*
pooled_width
;
const
int64_t
*
argmax_data
=
argmax
->
data
<
int64_t
>
();
int
rois_num
=
rois
->
dims
()[
0
];
int
channels
=
in
->
dims
()[
1
];
for
(
size_
t
n
=
0
;
n
<
rois_num
;
++
n
)
{
size_
t
roi_batch_idx
=
rois_data
[
0
];
T
*
batch_grad_data
=
x_grad_data
+
batch_offset
*
roi_batch_idx
;
for
(
in
t
n
=
0
;
n
<
rois_num
;
++
n
)
{
in
t
roi_batch_idx
=
rois_data
[
0
];
T
*
batch_grad_data
=
in_grad_data
+
roi_batch_idx
*
in_stride
[
0
]
;
for
(
int
c
=
0
;
c
<
channels
;
++
c
)
{
for
(
int
ph
=
0
;
ph
<
pooled_height
;
++
ph
)
{
for
(
int
pw
=
0
;
pw
<
pooled_width
;
++
pw
)
{
size_t
pool_index
=
ph
*
pooled_width
+
pw
;
int
pool_index
=
ph
*
pooled_width
+
pw
;
if
(
argmax_data
[
pool_index
]
>=
0
)
{
size_t
index
=
static_cast
<
size_t
>
(
argmax_data
[
pool_index
])
;
auto
index
=
argmax_data
[
pool_index
]
;
batch_grad_data
[
index
]
+=
out_grad_data
[
pool_index
];
}
}
}
batch_grad_data
+=
channel_offset
;
out_grad_data
+=
pool_channel_offset
;
argmax_data
+=
pool_channel_offset
;
batch_grad_data
+=
in_stride
[
1
]
;
out_grad_data
+=
out_stride
[
1
]
;
argmax_data
+=
argmax_stride
[
1
]
;
}
rois_data
+=
roi_
offset
;
rois_data
+=
roi_
stride
[
0
]
;
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录