Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
e21e5646
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e21e5646
编写于
10月 10, 2017
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix atomicAdd -> CudaAtomicAdd
上级
6c6474cb
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
4 addition
and
8 deletion
+4
-8
paddle/operators/math/pooling.cu
paddle/operators/math/pooling.cu
+4
-8
未找到文件。
paddle/operators/math/pooling.cu
浏览文件 @
e21e5646
...
...
@@ -144,7 +144,7 @@ __global__ void KernelMaxPool2DGrad(
if
(
maxIndex
!=
-
1
)
{
// atomic add
a
tomicAdd
(
input_grad
+
maxIndex
,
output_grad
[
index
]);
platform
::
CudaA
tomicAdd
(
input_grad
+
maxIndex
,
output_grad
[
index
]);
}
}
}
...
...
@@ -278,9 +278,7 @@ class MaxPool2dGradFunctor<platform::GPUPlace, T> {
};
template
class
MaxPool2dGradFunctor
<
platform
::
GPUPlace
,
float
>;
// template class MaxPool2dGradFunctor<platform::GPUPlace, double>; // The
// 64-bit floating-point version of atomicAdd() is only supported by devices of
// compute capability 6.x and higher.
template
class
MaxPool2dGradFunctor
<
platform
::
GPUPlace
,
double
>;
template
class
Pool2dFunctor
<
platform
::
GPUPlace
,
paddle
::
operators
::
math
::
MaxPool
<
float
>,
float
>
;
...
...
@@ -453,7 +451,7 @@ __global__ void KernelMaxPool3DGrad(
}
if
(
maxIdx
!=
-
1
)
{
// atomic add
a
tomicAdd
(
input_grad
+
maxIdx
,
output_grad
[
index
]);
platform
::
CudaA
tomicAdd
(
input_grad
+
maxIdx
,
output_grad
[
index
]);
}
}
}
...
...
@@ -609,9 +607,7 @@ class MaxPool3dGradFunctor<platform::GPUPlace, T> {
};
template
class
MaxPool3dGradFunctor
<
platform
::
GPUPlace
,
float
>;
// template class MaxPool3dGradFunctor<platform::GPUPlace, double>; // The
// 64-bit floating-point version of atomicAdd() is only supported by devices of
// compute capability 6.x and higher.
template
class
MaxPool3dGradFunctor
<
platform
::
GPUPlace
,
double
>;
template
class
Pool3dFunctor
<
platform
::
GPUPlace
,
paddle
::
operators
::
math
::
MaxPool
<
float
>,
float
>
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录