Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
5eec8cf5
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
5eec8cf5
编写于
12月 10, 2019
作者:
W
wangchaochaohu
提交者:
GitHub
12月 10, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix the mean grad OP performance improvement test=develop (#21658)
上级
29f64c8c
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
9 addition
and
6 deletion
+9
-6
paddle/fluid/operators/mean_op.cu
paddle/fluid/operators/mean_op.cu
+9
-6
未找到文件。
paddle/fluid/operators/mean_op.cu
浏览文件 @
5eec8cf5
...
...
@@ -31,10 +31,11 @@ struct DivideFunctor {
};
template
<
typename
T
>
__global__
void
MeanRunKernel
(
const
T
in_data
,
T
*
out_data
,
int
N
)
{
__global__
void
MeanRunKernel
(
const
T
*
in_data
,
T
*
out_data
,
int
N
)
{
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
T
data
=
in_data
[
0
];
for
(;
idx
<
N
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
out_data
[
idx
]
=
in_
data
/
(
static_cast
<
T
>
(
N
));
out_data
[
idx
]
=
data
/
(
static_cast
<
T
>
(
N
));
}
}
...
...
@@ -85,7 +86,7 @@ class MeanCUDAGradKernel : public framework::OpKernel<T> {
auto
IG
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
IG
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
in_data
=
OG
[
0
]
;
auto
in_data
=
OG
->
data
<
T
>
()
;
auto
size_prob
=
IG
->
numel
();
auto
out_data
=
IG
->
data
<
T
>
();
int
threads
=
512
;
...
...
@@ -105,6 +106,8 @@ REGISTER_OP_CUDA_KERNEL(
ops
::
MeanCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
MeanCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
mean_grad
,
ops
::
MeanGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
MeanGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
MeanGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
float16
>
);
mean_grad
,
ops
::
MeanCUDAGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
MeanCUDAGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
MeanCUDAGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
plat
::
float16
>
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录