Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
0112c5d6
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0112c5d6
编写于
11月 22, 2017
作者:
S
sweetsky0901
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
format code
上级
47bd0bb6
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
15 addition
and
16 deletion
+15
-16
paddle/operators/math/unpooling.cc
paddle/operators/math/unpooling.cc
+0
-1
paddle/operators/math/unpooling.cu
paddle/operators/math/unpooling.cu
+15
-15
未找到文件。
paddle/operators/math/unpooling.cc
浏览文件 @
0112c5d6
...
...
@@ -69,7 +69,6 @@ public:
const
int
output_channels
=
output
.
dims
()[
1
];
const
int
output_height
=
output
.
dims
()[
2
];
const
int
output_width
=
output
.
dims
()[
3
];
int
input_feasize
=
input_height
*
input_width
;
int
output_feasize
=
output_height
*
output_width
;
const
T
*
indices_data
=
indices
.
data
<
T
>
();
...
...
paddle/operators/math/unpooling.cu
浏览文件 @
0112c5d6
...
...
@@ -29,21 +29,21 @@ __global__ void KernelUnpool2dMax(const int nthreads,
T
*
output_data
,
const
int
output_height
,
const
int
output_width
)
{
int
bsize
=
input_height
*
input_width
*
channels
;
int
csize
=
input_height
*
input_width
;
int
out_bsize
=
output_height
*
output_width
*
channels
;
int
out_csize
=
output_height
*
output_width
;
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
offset
=
blockDim
.
x
*
gridDim
.
x
;
for
(
int
i
=
index
;
i
<
nthreads
;
i
+=
offset
)
{
int
bidx
=
i
/
bsize
;
int
boffset
=
i
%
bsize
;
int
cidx
=
boffset
/
csize
;
int
out_offset
=
bidx
*
out_bsize
+
cidx
*
out_csize
;
int
out_index
=
indices_data
[
i
];
PADDLE_ASSERT
(
out_index
<
(
output_height
*
output_width
));
output_data
[
out_offset
+
out_index
]
=
input_data
[
i
];
}
int
bsize
=
input_height
*
input_width
*
channels
;
int
csize
=
input_height
*
input_width
;
int
out_bsize
=
output_height
*
output_width
*
channels
;
int
out_csize
=
output_height
*
output_width
;
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
offset
=
blockDim
.
x
*
gridDim
.
x
;
for
(
int
i
=
index
;
i
<
nthreads
;
i
+=
offset
)
{
int
bidx
=
i
/
bsize
;
int
boffset
=
i
%
bsize
;
int
cidx
=
boffset
/
csize
;
int
out_offset
=
bidx
*
out_bsize
+
cidx
*
out_csize
;
int
out_index
=
indices_data
[
i
];
PADDLE_ASSERT
(
out_index
<
(
output_height
*
output_width
));
output_data
[
out_offset
+
out_index
]
=
input_data
[
i
];
}
}
template
<
typename
T
>
__global__
void
KernelUnpool2dMaxGrad
(
const
int
nthreads
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录