Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
d0b601c4
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
d0b601c4
编写于
11月 15, 2017
作者:
M
Markus Kliegl
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
address PR feedback
上级
42dd5da0
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
20 addition
and
17 deletion
+20
-17
paddle/operators/conv_shift_op.cu
paddle/operators/conv_shift_op.cu
+20
-17
未找到文件。
paddle/operators/conv_shift_op.cu
浏览文件 @
d0b601c4
...
...
@@ -13,6 +13,7 @@
limitations under the License. */
#include "paddle/operators/conv_shift_op.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/platform/cuda_helper.h"
namespace
paddle
{
...
...
@@ -33,9 +34,9 @@ inline int DivUp(int x, int y) { return (x + y - 1) / y; }
// y is fairly small. For large y, it would probably be more efficient
// to also tile across y.
template
<
typename
T
>
__global__
void
ConvShiftForward
(
const
T
*
x
,
const
T
*
y
,
T
*
out
,
int
x_width
,
int
y_width
,
int
y_half_width
,
int
batch_size
)
{
__global__
void
ConvShiftForward
(
const
T
*
x
,
const
T
*
y
,
int
x_width
,
int
y_width
,
int
y_half_width
,
int
batch_size
,
T
*
out
)
{
extern
__shared__
T
mem
[];
int
tx
=
threadIdx
.
x
;
...
...
@@ -79,8 +80,9 @@ __global__ void ConvShiftForward(const T *x, const T *y, T *out, int x_width,
// Compute x gradient - initial naive implementation with atomic add.
template
<
typename
T
>
__global__
void
ConvShiftGradX
(
const
T
*
dout
,
const
T
*
y
,
T
*
dx
,
int
x_width
,
int
y_width
,
int
y_half_width
,
int
batch_size
)
{
__global__
void
ConvShiftGradX
(
const
T
*
dout
,
const
T
*
y
,
int
x_width
,
int
y_width
,
int
y_half_width
,
int
batch_size
,
T
*
dx
)
{
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
// x index
int
j
=
blockIdx
.
y
;
// y index
int
k
=
blockIdx
.
z
;
// batch index
...
...
@@ -94,8 +96,8 @@ __global__ void ConvShiftGradX(const T *dout, const T *y, T *dx, int x_width,
// Compute y gradient - initial naive implementation with atomic add.
template
<
typename
T
>
__global__
void
ConvShiftDy
(
const
T
*
x
,
const
T
*
dout
,
T
*
dy
,
int
x
_width
,
int
y_
width
,
int
y_half_width
,
int
batch_size
)
{
__global__
void
ConvShiftDy
(
const
T
*
x
,
const
T
*
dout
,
int
x_width
,
int
y
_width
,
int
y_
half_width
,
int
batch_size
,
T
*
dy
)
{
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
// x index
int
j
=
blockIdx
.
y
;
// y index
int
k
=
blockIdx
.
z
;
// batch index
...
...
@@ -133,7 +135,7 @@ class ConvShiftKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
auto
stream
=
context
.
cuda_device_context
().
stream
();
ConvShiftForward
<
T
><<<
grid_dim
,
x_per_block
,
mem_per_block
,
stream
>>>
(
x_data
,
y_data
,
out_data
,
x_width
,
y_width
,
y_half_width
,
batch_size
);
x_data
,
y_data
,
x_width
,
y_width
,
y_half_width
,
batch_size
,
out_data
);
}
};
...
...
@@ -157,7 +159,8 @@ class ConvShiftGradKernel<platform::GPUPlace, T>
int
y_width
=
Y
->
dims
()[
1
];
int
y_half_width
=
(
y_width
-
1
)
/
2
;
auto
stream
=
context
.
cuda_device_context
().
stream
();
auto
&
device_ctx
=
context
.
cuda_device_context
();
math
::
SetConstant
<
platform
::
GPUPlace
,
T
>
zero
;
const
int
x_per_block
=
256
;
int
num_x_blocks
=
DivUp
(
x_width
,
x_per_block
);
...
...
@@ -165,17 +168,17 @@ class ConvShiftGradKernel<platform::GPUPlace, T>
if
(
dX
)
{
T
*
dx_data
=
dX
->
mutable_data
<
T
>
(
context
.
GetPlace
());
cudaMemsetAsync
(
dx_data
,
0
,
dX
->
numel
()
*
sizeof
(
T
),
stream
);
ConvShiftGradX
<
T
><<<
grid_dim
,
x_per_block
,
0
,
stream
>>>
(
dout_data
,
y_data
,
dx_data
,
x_width
,
y_width
,
y_half_width
,
batch_size
);
zero
(
device_ctx
,
dX
,
static_cast
<
T
>
(
0.0
)
);
ConvShiftGradX
<
T
><<<
grid_dim
,
x_per_block
,
0
,
device_ctx
.
stream
()
>>>
(
dout_data
,
y_data
,
x_width
,
y_width
,
y_half_width
,
batch_size
,
dx_data
);
}
if
(
dY
)
{
T
*
dy_data
=
dY
->
mutable_data
<
T
>
(
context
.
GetPlace
());
cudaMemsetAsync
(
dy_data
,
0
,
dY
->
numel
()
*
sizeof
(
T
),
stream
);
ConvShiftDy
<
T
><<<
grid_dim
,
x_per_block
,
0
,
stream
>>>
(
x_data
,
dout_data
,
dy_data
,
x_width
,
y_width
,
y_half_width
,
batch_size
);
zero
(
device_ctx
,
dY
,
static_cast
<
T
>
(
0.0
)
);
ConvShiftDy
<
T
><<<
grid_dim
,
x_per_block
,
0
,
device_ctx
.
stream
()
>>>
(
x_data
,
dout_data
,
x_width
,
y_width
,
y_half_width
,
batch_size
,
dy_data
);
}
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录