Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
2d13462a
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
694
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
2d13462a
编写于
9月 19, 2016
作者:
L
liaogang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix incompatible on CUDA atomicAdd operation
上级
4e37b226
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
44 addition
and
35 deletion
+44
-35
paddle/cuda/include/hl_device_functions.cuh
paddle/cuda/include/hl_device_functions.cuh
+29
-20
paddle/cuda/include/hl_gpu_lstm.cuh
paddle/cuda/include/hl_gpu_lstm.cuh
+3
-3
paddle/cuda/src/hl_cuda_lstm.cu
paddle/cuda/src/hl_cuda_lstm.cu
+3
-3
paddle/cuda/src/hl_cuda_matrix.cu
paddle/cuda/src/hl_cuda_matrix.cu
+2
-2
paddle/cuda/src/hl_cuda_sequence.cu
paddle/cuda/src/hl_cuda_sequence.cu
+1
-1
paddle/cuda/src/hl_cuda_sparse.cuh
paddle/cuda/src/hl_cuda_sparse.cuh
+5
-5
paddle/cuda/src/hl_table_apply.cu
paddle/cuda/src/hl_table_apply.cu
+1
-1
未找到文件。
paddle/cuda/include/hl_device_functions.cuh
浏览文件 @
2d13462a
...
...
@@ -16,28 +16,37 @@ limitations under the License. */
#ifndef HL_DEVICE_FUNCTIONS_CUH_
#define HL_DEVICE_FUNCTIONS_CUH_
namespace
hppl
{
static
__inline__
__device__
double
atomicAdd
(
double
*
address
,
double
val
)
{
// NOLINTNEXTLINE
unsigned
long
long
int
*
address_as_ull
=
(
unsigned
long
long
int
*
)
address
;
unsigned
long
long
int
old
=
*
address_as_ull
,
assumed
;
// NOLINT
do
{
assumed
=
old
;
old
=
atomicCAS
(
address_as_ull
,
assumed
,
__double_as_longlong
(
val
+
__longlong_as_double
(
assumed
)));
}
while
(
assumed
!=
old
);
return
__longlong_as_double
(
old
);
}
namespace
paddle
{
template
<
class
T
>
inline
__device__
T
paddleAtomicAdd
(
T
*
address
,
T
val
);
}
// namespace hppl
template
<
>
inline
__device__
float
paddleAtomicAdd
(
float
*
address
,
float
val
)
{
return
atomicAdd
(
address
,
val
);
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
using
hppl
::
atomicAdd
;
template
<
>
inline
__device__
double
paddleAtomicAdd
(
double
*
address
,
double
val
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
return
atomicAdd
(
address
,
val
);
#else
// NOLINTNEXTLINE
unsigned
long
long
int
*
address_as_ull
=
(
unsigned
long
long
int
*
)
address
;
unsigned
long
long
int
old
=
*
address_as_ull
,
assumed
;
// NOLINT
do
{
assumed
=
old
;
old
=
atomicCAS
(
address_as_ull
,
assumed
,
__double_as_longlong
(
val
+
__longlong_as_double
(
assumed
)));
}
while
(
assumed
!=
old
);
return
__longlong_as_double
(
old
);
#endif
}
}
// namespace paddle
#endif
/* HL_DEVICE_FUNCTIONS_CUH_ */
paddle/cuda/include/hl_gpu_lstm.cuh
浏览文件 @
2d13462a
...
...
@@ -192,10 +192,10 @@ __global__ void KeLstmBackward(Op op,
if
(
isBatch
)
{
if
(
value
.
prevStateValue
)
{
if
(
grad
.
checkIgGrad
)
a
tomicAdd
(
grad
.
checkIgGrad
+
frameIdx
,
rCheckIGrad
);
if
(
grad
.
checkFgGrad
)
a
tomicAdd
(
grad
.
checkFgGrad
+
frameIdx
,
rCheckFGrad
);
if
(
grad
.
checkIgGrad
)
paddle
::
paddleA
tomicAdd
(
grad
.
checkIgGrad
+
frameIdx
,
rCheckIGrad
);
if
(
grad
.
checkFgGrad
)
paddle
::
paddleA
tomicAdd
(
grad
.
checkFgGrad
+
frameIdx
,
rCheckFGrad
);
}
if
(
grad
.
checkOgGrad
)
a
tomicAdd
(
grad
.
checkOgGrad
+
frameIdx
,
rCheckOGrad
);
if
(
grad
.
checkOgGrad
)
paddle
::
paddleA
tomicAdd
(
grad
.
checkOgGrad
+
frameIdx
,
rCheckOGrad
);
}
else
{
if
(
value
.
prevStateValue
)
{
if
(
grad
.
checkIgGrad
)
grad
.
checkIgGrad
[
frameIdx
]
+=
rCheckIGrad
;
...
...
paddle/cuda/src/hl_cuda_lstm.cu
浏览文件 @
2d13462a
...
...
@@ -564,11 +564,11 @@ __global__ void KeLstmBackward(real *gateValue,
/* TODO: Temporary save & merger in another kernel */
if
(
frameIdy
==
1
)
{
if
(
checkIgGrad
)
a
tomicAdd
(
checkIgGrad
+
frameIdx
,
rCheckGrad
);
if
(
checkIgGrad
)
paddle
::
paddleA
tomicAdd
(
checkIgGrad
+
frameIdx
,
rCheckGrad
);
}
else
if
(
frameIdy
==
2
)
{
if
(
checkFgGrad
)
a
tomicAdd
(
checkFgGrad
+
frameIdx
,
rCheckGrad
);
if
(
checkFgGrad
)
paddle
::
paddleA
tomicAdd
(
checkFgGrad
+
frameIdx
,
rCheckGrad
);
}
else
if
(
frameIdy
==
3
)
{
if
(
checkOgGrad
)
a
tomicAdd
(
checkOgGrad
+
frameIdx
,
rCheckGrad
);
if
(
checkOgGrad
)
paddle
::
paddleA
tomicAdd
(
checkOgGrad
+
frameIdx
,
rCheckGrad
);
}
}
...
...
paddle/cuda/src/hl_cuda_matrix.cu
浏览文件 @
2d13462a
...
...
@@ -623,7 +623,7 @@ __global__ void KeCosSimDerivative(real* grad,
prevGradY
[
index
]
+=
scale
*
grad
[
ty
]
*
prevOutX
[
index
]
*
reciprocal
;
}
else
{
a
tomicAdd
(
prevGradY
+
index
,
paddle
::
paddleA
tomicAdd
(
prevGradY
+
index
,
scale
*
grad
[
ty
]
*
prevOutX
[
index
]
*
reciprocal
);
}
}
...
...
@@ -640,7 +640,7 @@ __global__ void KeCosSimDerivative(real* grad,
(
prevOutX
[
index
]
*
reciprocalXY
-
prevOutY
[
index
]
*
reciprocalSquareSumY
);
}
else
{
a
tomicAdd
(
prevGradY
+
index
,
output
[
ty
]
*
grad
[
ty
]
*
paddle
::
paddleA
tomicAdd
(
prevGradY
+
index
,
output
[
ty
]
*
grad
[
ty
]
*
(
prevOutX
[
index
]
*
reciprocalXY
-
prevOutY
[
index
]
*
reciprocalSquareSumY
));
}
...
...
paddle/cuda/src/hl_cuda_sequence.cu
浏览文件 @
2d13462a
...
...
@@ -362,7 +362,7 @@ __global__ void KeMatrixAddRows(real* output,
if
(
AddRow
==
0
)
{
outputData
[
i
]
+=
tableData
[
i
];
}
else
{
a
tomicAdd
(
&
tableData
[
i
],
outputData
[
i
]);
paddle
::
paddleA
tomicAdd
(
&
tableData
[
i
],
outputData
[
i
]);
}
}
}
...
...
paddle/cuda/src/hl_cuda_sparse.cuh
浏览文件 @
2d13462a
...
...
@@ -280,7 +280,7 @@ __global__ void KeSMatrixCscMulDense(real *C_d,
if
(
index_n_t
<
dimN
)
{
real
tmp
;
tmp
=
alpha
*
a_r
*
b_r
[
n
];
a
tomicAdd
(
C_d_r
,
tmp
);
paddle
::
paddleA
tomicAdd
(
C_d_r
,
tmp
);
C_d_r
+=
CU_CSC_MUL_DENSE_THREAD_X
;
index_n_t
+=
CU_CSC_MUL_DENSE_THREAD_X
;
}
...
...
@@ -328,7 +328,7 @@ __global__ void KeSMatrixCscMulDense(real *C_d,
if
(
index_n_t
<
dimN
)
{
real
tmp
;
tmp
=
alpha
*
a_r
*
b_r
[
n
];
a
tomicAdd
(
C_d_r
,
tmp
);
paddle
::
paddleA
tomicAdd
(
C_d_r
,
tmp
);
C_d_r
+=
CU_CSC_MUL_DENSE_THREAD_X
;
index_n_t
+=
CU_CSC_MUL_DENSE_THREAD_X
;
}
...
...
@@ -629,7 +629,7 @@ __global__ void KeSMatrixDenseMulCsr(real *C_d,
for
(
int
n
=
0
;
n
<
CU_DM_CSR_N
;
n
++
)
{
if
(
index_m_t
++
<
dimM
)
{
tmp
=
alpha
*
b_r
*
a_r
[
n
];
a
tomicAdd
(
C_d_r
,
tmp
);
paddle
::
paddleA
tomicAdd
(
C_d_r
,
tmp
);
C_d_r
+=
dimN
;
}
}
...
...
@@ -660,7 +660,7 @@ __global__ void KeSMatrixDenseMulCsr(real *C_d,
for
(
int
n
=
0
;
n
<
CU_DM_CSR_N
;
n
++
)
{
if
(
index_m_t
++
<
dimM
)
{
tmp
=
alpha
*
b_r
*
a_r
[
n
];
a
tomicAdd
(
C_d_r
,
tmp
);
paddle
::
paddleA
tomicAdd
(
C_d_r
,
tmp
);
C_d_r
+=
dimN
;
}
}
...
...
@@ -912,7 +912,7 @@ __global__ void KeSMatrixCsrColumnSum(real* a_val, real* csr_val,
for
(
int
idx
=
gid
;
idx
<
dimNNZ
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
int
colIdx
=
csr_col
[
idx
];
real
val
=
csr_val
[
idx
];
a
tomicAdd
(
a_val
+
colIdx
,
val
);
paddle
::
paddleA
tomicAdd
(
a_val
+
colIdx
,
val
);
}
}
...
...
paddle/cuda/src/hl_table_apply.cu
浏览文件 @
2d13462a
...
...
@@ -35,7 +35,7 @@ __global__ void KeMatrixAddRows(real* output, int ldo,
real
*
tab
=
table
+
tableId
*
ldt
;
for
(
int
i
=
idx
;
i
<
dim
;
i
+=
blockDimX
)
{
if
(
AddRow
)
{
a
tomicAdd
(
&
tab
[
i
],
out
[
i
]);
paddle
::
paddleA
tomicAdd
(
&
tab
[
i
],
out
[
i
]);
}
else
{
out
[
i
]
+=
tab
[
i
];
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录