Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1d4fa243
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1d4fa243
编写于
8月 04, 2017
作者:
L
liaogang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
ClangFormat for proto and cuda
上级
6512893b
变更
39
展开全部
隐藏空白更改
内联
并排
Showing
39 changed file
with
3661 addition
and
2920 deletion
+3661
-2920
.pre-commit-config.yaml
.pre-commit-config.yaml
+1
-1
paddle/cuda/src/hl_batch_transpose.cu
paddle/cuda/src/hl_batch_transpose.cu
+7
-9
paddle/cuda/src/hl_cuda_aggregate.cu
paddle/cuda/src/hl_cuda_aggregate.cu
+61
-101
paddle/cuda/src/hl_cuda_cnn.cu
paddle/cuda/src/hl_cuda_cnn.cu
+275
-134
paddle/cuda/src/hl_cuda_lstm.cu
paddle/cuda/src/hl_cuda_lstm.cu
+331
-159
paddle/cuda/src/hl_cuda_matrix.cu
paddle/cuda/src/hl_cuda_matrix.cu
+147
-196
paddle/cuda/src/hl_cuda_sequence.cu
paddle/cuda/src/hl_cuda_sequence.cu
+96
-88
paddle/cuda/src/hl_cuda_sparse.cu
paddle/cuda/src/hl_cuda_sparse.cu
+475
-509
paddle/cuda/src/hl_perturbation_util.cu
paddle/cuda/src/hl_perturbation_util.cu
+104
-45
paddle/cuda/src/hl_table_apply.cu
paddle/cuda/src/hl_table_apply.cu
+35
-33
paddle/cuda/src/hl_top_k.cu
paddle/cuda/src/hl_top_k.cu
+127
-114
paddle/framework/attr_type.proto
paddle/framework/attr_type.proto
+7
-7
paddle/framework/op_desc.proto
paddle/framework/op_desc.proto
+17
-17
paddle/framework/op_proto.proto
paddle/framework/op_proto.proto
+72
-70
paddle/function/ContextProjectionOpGpu.cu
paddle/function/ContextProjectionOpGpu.cu
+70
-56
paddle/function/CosSimOpGpu.cu
paddle/function/CosSimOpGpu.cu
+34
-26
paddle/function/CropOpGpu.cu
paddle/function/CropOpGpu.cu
+59
-25
paddle/function/CrossMapNormalOpGpu.cu
paddle/function/CrossMapNormalOpGpu.cu
+46
-25
paddle/function/DepthwiseConvOpGpu.cu
paddle/function/DepthwiseConvOpGpu.cu
+253
-218
paddle/function/Im2ColOpGpu.cu
paddle/function/Im2ColOpGpu.cu
+150
-106
paddle/function/MulOpGpu.cu
paddle/function/MulOpGpu.cu
+1
-1
paddle/function/PadOpGpu.cu
paddle/function/PadOpGpu.cu
+49
-15
paddle/function/RowConvOpGpu.cu
paddle/function/RowConvOpGpu.cu
+87
-68
paddle/gserver/layers/GruCompute.cu
paddle/gserver/layers/GruCompute.cu
+4
-3
paddle/gserver/layers/LstmCompute.cu
paddle/gserver/layers/LstmCompute.cu
+38
-17
paddle/math/BaseMatrix.cu
paddle/math/BaseMatrix.cu
+619
-366
paddle/math/TrainingAlgorithmOp.cu
paddle/math/TrainingAlgorithmOp.cu
+32
-33
paddle/math/tests/test_Tensor.cu
paddle/math/tests/test_Tensor.cu
+167
-170
paddle/math/tests/test_lazyAssign.cu
paddle/math/tests/test_lazyAssign.cu
+40
-34
paddle/operators/softmax_op.cu
paddle/operators/softmax_op.cu
+2
-1
paddle/trainer/tests/pydata_provider_wrapper_dir/test_pydata_provider_wrapper.proto
...a_provider_wrapper_dir/test_pydata_provider_wrapper.proto
+0
-0
proto/DataConfig.proto
proto/DataConfig.proto
+27
-26
proto/DataFormat.proto
proto/DataFormat.proto
+22
-16
proto/ModelConfig.proto
proto/ModelConfig.proto
+57
-57
proto/OptimizerConfig.proto
proto/OptimizerConfig.proto
+36
-36
proto/ParameterConfig.proto
proto/ParameterConfig.proto
+23
-22
proto/ParameterServerConfig.proto
proto/ParameterServerConfig.proto
+10
-13
proto/ParameterService.proto
proto/ParameterService.proto
+37
-64
proto/TrainerConfig.proto
proto/TrainerConfig.proto
+43
-39
未找到文件。
.pre-commit-config.yaml
浏览文件 @
1d4fa243
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
description
:
Format files with ClangFormat.
description
:
Format files with ClangFormat.
entry
:
clang-format -i
entry
:
clang-format -i
language
:
system
language
:
system
files
:
\.(c|cc|cxx|cpp|
h|hpp|hxx
)$
files
:
\.(c|cc|cxx|cpp|
cu|h|hpp|hxx|proto
)$
-
repo
:
https://github.com/PaddlePaddle/pre-commit-golang
-
repo
:
https://github.com/PaddlePaddle/pre-commit-golang
sha
:
8337620115c25ff8333f1b1a493bd031049bd7c0
sha
:
8337620115c25ff8333f1b1a493bd031049bd7c0
hooks
:
hooks
:
...
...
paddle/cuda/src/hl_batch_transpose.cu
浏览文件 @
1d4fa243
...
@@ -12,17 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,17 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "hl_batch_transpose.h"
#include "hl_base.h"
#include "hl_base.h"
#include "hl_batch_transpose.h"
const
int
TILE_DIM
=
64
;
const
int
TILE_DIM
=
64
;
const
int
BLOCK_ROWS
=
16
;
const
int
BLOCK_ROWS
=
16
;
// No bank-conflict transpose for a batch of data.
// No bank-conflict transpose for a batch of data.
__global__
void
batchTransposeNoBankConflicts
(
real
*
odata
,
__global__
void
batchTransposeNoBankConflicts
(
const
real
*
idata
,
real
*
odata
,
const
real
*
idata
,
int
numSamples
,
int
width
,
int
height
)
{
int
numSamples
,
int
width
,
int
height
)
{
__shared__
float
tile
[
TILE_DIM
][
TILE_DIM
+
1
];
__shared__
float
tile
[
TILE_DIM
][
TILE_DIM
+
1
];
const
int
x
=
blockIdx
.
x
*
TILE_DIM
+
threadIdx
.
x
;
const
int
x
=
blockIdx
.
x
*
TILE_DIM
+
threadIdx
.
x
;
...
@@ -50,12 +48,12 @@ __global__ void batchTransposeNoBankConflicts(real* odata,
...
@@ -50,12 +48,12 @@ __global__ void batchTransposeNoBankConflicts(real* odata,
newX
]
=
tile
[
threadIdx
.
x
][
j
];
newX
]
=
tile
[
threadIdx
.
x
][
j
];
}
}
void
batchTranspose
(
const
real
*
input
,
real
*
output
,
int
width
,
int
height
,
void
batchTranspose
(
int
batchSize
)
{
const
real
*
input
,
real
*
output
,
int
width
,
int
height
,
int
batchSize
)
{
dim3
dimBlock
(
TILE_DIM
,
BLOCK_ROWS
,
1
);
dim3
dimBlock
(
TILE_DIM
,
BLOCK_ROWS
,
1
);
dim3
dimGrid
(
DIVUP
(
width
,
TILE_DIM
),
DIVUP
(
height
,
TILE_DIM
),
batchSize
);
dim3
dimGrid
(
DIVUP
(
width
,
TILE_DIM
),
DIVUP
(
height
,
TILE_DIM
),
batchSize
);
batchTransposeNoBankConflicts
<<<
dimGrid
,
dimBlock
,
0
,
STREAM_DEFAULT
>>>
batchTransposeNoBankConflicts
<<<
dimGrid
,
dimBlock
,
0
,
STREAM_DEFAULT
>>>
(
(
output
,
input
,
batchSize
,
width
,
height
);
output
,
input
,
batchSize
,
width
,
height
);
CHECK_SYNC
(
"batchTranspose failed!"
);
CHECK_SYNC
(
"batchTranspose failed!"
);
}
}
paddle/cuda/src/hl_cuda_aggregate.cu
浏览文件 @
1d4fa243
...
@@ -12,27 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,27 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "hl_aggregate.h"
#include "hl_base.h"
#include "hl_base.h"
#include "hl_cuda.h"
#include "hl_cuda.h"
#include "hl_cuda.ph"
#include "hl_cuda.ph"
#include "hl_aggregate.h"
#include "hl_thread.ph"
#include "hl_matrix_base.cuh"
#include "hl_matrix_base.cuh"
#include "hl_thread.ph"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Logging.h"
/**
/**
* @brief matrix row operator.
* @brief matrix row operator.
*/
*/
template
<
class
Agg
,
int
blockSize
>
template
<
class
Agg
,
int
blockSize
>
__global__
void
KeMatrixRowOp
(
Agg
agg
,
__global__
void
KeMatrixRowOp
(
Agg
agg
,
real
*
E
,
real
*
Sum
,
int
dimN
)
{
real
*
E
,
real
*
Sum
,
int
dimN
)
{
__shared__
real
sum_s
[
blockSize
];
__shared__
real
sum_s
[
blockSize
];
int
cnt
=
(
dimN
+
blockSize
-
1
)
/
blockSize
;
int
cnt
=
(
dimN
+
blockSize
-
1
)
/
blockSize
;
int
rowId
=
blockIdx
.
x
+
blockIdx
.
y
*
gridDim
.
x
;
int
rowId
=
blockIdx
.
x
+
blockIdx
.
y
*
gridDim
.
x
;
int
index
=
rowId
*
dimN
;
int
index
=
rowId
*
dimN
;
int
tid
=
threadIdx
.
x
;
int
tid
=
threadIdx
.
x
;
int
lmt
=
tid
;
int
lmt
=
tid
;
...
@@ -44,7 +40,7 @@ __global__ void KeMatrixRowOp(Agg agg,
...
@@ -44,7 +40,7 @@ __global__ void KeMatrixRowOp(Agg agg,
sum_s
[
tid
]
=
tmp
;
sum_s
[
tid
]
=
tmp
;
__syncthreads
();
__syncthreads
();
for
(
int
stride
=
blockSize
/
2
;
stride
>
0
;
stride
=
stride
/
2
)
{
for
(
int
stride
=
blockSize
/
2
;
stride
>
0
;
stride
=
stride
/
2
)
{
if
(
tid
<
stride
)
{
if
(
tid
<
stride
)
{
sum_s
[
tid
]
=
agg
(
sum_s
[
tid
],
sum_s
[
tid
+
stride
]);
sum_s
[
tid
]
=
agg
(
sum_s
[
tid
],
sum_s
[
tid
+
stride
]);
}
}
...
@@ -58,29 +54,21 @@ __global__ void KeMatrixRowOp(Agg agg,
...
@@ -58,29 +54,21 @@ __global__ void KeMatrixRowOp(Agg agg,
}
}
template
<
class
Agg
>
template
<
class
Agg
>
void
hl_matrix_row_op
(
Agg
agg
,
void
hl_matrix_row_op
(
Agg
agg
,
real
*
A_d
,
real
*
C_d
,
int
dimM
,
int
dimN
)
{
real
*
A_d
,
real
*
C_d
,
int
dimM
,
int
dimN
)
{
int
blocksX
=
dimM
;
int
blocksX
=
dimM
;
int
blocksY
=
1
;
int
blocksY
=
1
;
dim3
threads
(
128
,
1
);
dim3
threads
(
128
,
1
);
dim3
grid
(
blocksX
,
blocksY
);
dim3
grid
(
blocksX
,
blocksY
);
KeMatrixRowOp
<
Agg
,
128
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
KeMatrixRowOp
<
Agg
,
128
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
(
agg
,
A_d
,
C_d
,
dimN
);
agg
,
A_d
,
C_d
,
dimN
);
}
}
void
hl_matrix_row_sum
(
real
*
A_d
,
real
*
C_d
,
int
dimM
,
int
dimN
)
{
void
hl_matrix_row_sum
(
real
*
A_d
,
real
*
C_d
,
int
dimM
,
int
dimN
)
{
CHECK_NOTNULL
(
A_d
);
CHECK_NOTNULL
(
A_d
);
CHECK_NOTNULL
(
C_d
);
CHECK_NOTNULL
(
C_d
);
hl_matrix_row_op
(
aggregate
::
sum
(),
hl_matrix_row_op
(
aggregate
::
sum
(),
A_d
,
C_d
,
dimM
,
dimN
);
A_d
,
C_d
,
dimM
,
dimN
);
CHECK_SYNC
(
"hl_matrix_row_sum failed"
);
CHECK_SYNC
(
"hl_matrix_row_sum failed"
);
}
}
...
@@ -88,11 +76,7 @@ void hl_matrix_row_max(real *A_d, real *C_d, int dimM, int dimN) {
...
@@ -88,11 +76,7 @@ void hl_matrix_row_max(real *A_d, real *C_d, int dimM, int dimN) {
CHECK_NOTNULL
(
A_d
);
CHECK_NOTNULL
(
A_d
);
CHECK_NOTNULL
(
C_d
);
CHECK_NOTNULL
(
C_d
);
hl_matrix_row_op
(
aggregate
::
max
(),
hl_matrix_row_op
(
aggregate
::
max
(),
A_d
,
C_d
,
dimM
,
dimN
);
A_d
,
C_d
,
dimM
,
dimN
);
CHECK_SYNC
(
"hl_matrix_row_max failed"
);
CHECK_SYNC
(
"hl_matrix_row_max failed"
);
}
}
...
@@ -100,23 +84,16 @@ void hl_matrix_row_min(real *A_d, real *C_d, int dimM, int dimN) {
...
@@ -100,23 +84,16 @@ void hl_matrix_row_min(real *A_d, real *C_d, int dimM, int dimN) {
CHECK_NOTNULL
(
A_d
);
CHECK_NOTNULL
(
A_d
);
CHECK_NOTNULL
(
C_d
);
CHECK_NOTNULL
(
C_d
);
hl_matrix_row_op
(
aggregate
::
min
(),
hl_matrix_row_op
(
aggregate
::
min
(),
A_d
,
C_d
,
dimM
,
dimN
);
A_d
,
C_d
,
dimM
,
dimN
);
CHECK_SYNC
(
"hl_matrix_row_min failed"
);
CHECK_SYNC
(
"hl_matrix_row_min failed"
);
}
}
/**
/**
* @brief matrix column operator.
* @brief matrix column operator.
*/
*/
template
<
class
Agg
>
template
<
class
Agg
>
__global__
void
KeMatrixColumnOp
(
Agg
agg
,
__global__
void
KeMatrixColumnOp
(
real
*
E
,
Agg
agg
,
real
*
E
,
real
*
Sum
,
int
dimM
,
int
dimN
)
{
real
*
Sum
,
int
dimM
,
int
dimN
)
{
int
rowIdx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
rowIdx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
real
tmp
=
agg
.
init
();
real
tmp
=
agg
.
init
();
if
(
rowIdx
<
dimN
)
{
if
(
rowIdx
<
dimN
)
{
...
@@ -127,15 +104,12 @@ __global__ void KeMatrixColumnOp(Agg agg,
...
@@ -127,15 +104,12 @@ __global__ void KeMatrixColumnOp(Agg agg,
}
}
}
}
template
<
class
Agg
,
int
blockDimX
,
int
blockDimY
>
template
<
class
Agg
,
int
blockDimX
,
int
blockDimY
>
__global__
void
KeMatrixColumnOp_S
(
Agg
agg
,
__global__
void
KeMatrixColumnOp_S
(
real
*
E
,
Agg
agg
,
real
*
E
,
real
*
Sum
,
int
dimM
,
int
dimN
)
{
real
*
Sum
,
__shared__
real
_sum
[
blockDimX
*
blockDimY
];
int
dimM
,
int
rowIdx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
dimN
)
{
int
index
=
threadIdx
.
y
;
__shared__
real
_sum
[
blockDimX
*
blockDimY
];
int
rowIdx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
index
=
threadIdx
.
y
;
real
tmp
=
agg
.
init
();
real
tmp
=
agg
.
init
();
if
(
rowIdx
<
dimN
)
{
if
(
rowIdx
<
dimN
)
{
...
@@ -144,14 +118,14 @@ __global__ void KeMatrixColumnOp_S(Agg agg,
...
@@ -144,14 +118,14 @@ __global__ void KeMatrixColumnOp_S(Agg agg,
index
+=
blockDimY
;
index
+=
blockDimY
;
}
}
}
}
_sum
[
threadIdx
.
x
+
threadIdx
.
y
*
blockDimX
]
=
tmp
;
_sum
[
threadIdx
.
x
+
threadIdx
.
y
*
blockDimX
]
=
tmp
;
__syncthreads
();
__syncthreads
();
if
(
rowIdx
<
dimN
)
{
if
(
rowIdx
<
dimN
)
{
if
(
threadIdx
.
y
==
0
)
{
if
(
threadIdx
.
y
==
0
)
{
real
tmp
=
agg
.
init
();
real
tmp
=
agg
.
init
();
for
(
int
i
=
0
;
i
<
blockDimY
;
i
++
)
{
for
(
int
i
=
0
;
i
<
blockDimY
;
i
++
)
{
tmp
=
agg
(
tmp
,
_sum
[
threadIdx
.
x
+
i
*
blockDimX
]);
tmp
=
agg
(
tmp
,
_sum
[
threadIdx
.
x
+
i
*
blockDimX
]);
}
}
Sum
[
rowIdx
]
=
tmp
;
Sum
[
rowIdx
]
=
tmp
;
}
}
...
@@ -159,25 +133,21 @@ __global__ void KeMatrixColumnOp_S(Agg agg,
...
@@ -159,25 +133,21 @@ __global__ void KeMatrixColumnOp_S(Agg agg,
}
}
template
<
class
Agg
>
template
<
class
Agg
>
void
hl_matrix_column_op
(
Agg
agg
,
void
hl_matrix_column_op
(
Agg
agg
,
real
*
A_d
,
real
*
C_d
,
int
dimM
,
int
dimN
)
{
real
*
A_d
,
real
*
C_d
,
int
dimM
,
int
dimN
)
{
if
(
dimN
>=
8192
)
{
if
(
dimN
>=
8192
)
{
int
blocksX
=
(
dimN
+
128
-
1
)
/
128
;
int
blocksX
=
(
dimN
+
128
-
1
)
/
128
;
int
blocksY
=
1
;
int
blocksY
=
1
;
dim3
threads
(
128
,
1
);
dim3
threads
(
128
,
1
);
dim3
grid
(
blocksX
,
blocksY
);
dim3
grid
(
blocksX
,
blocksY
);
KeMatrixColumnOp
<
Agg
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
KeMatrixColumnOp
<
Agg
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
(
agg
,
A_d
,
C_d
,
dimM
,
dimN
);
agg
,
A_d
,
C_d
,
dimM
,
dimN
);
}
else
{
}
else
{
int
blocksX
=
(
dimN
+
32
-
1
)
/
32
;
int
blocksX
=
(
dimN
+
32
-
1
)
/
32
;
int
blocksY
=
1
;
int
blocksY
=
1
;
dim3
threads
(
32
,
32
);
dim3
threads
(
32
,
32
);
dim3
grid
(
blocksX
,
blocksY
);
dim3
grid
(
blocksX
,
blocksY
);
KeMatrixColumnOp_S
<
Agg
,
32
,
32
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
KeMatrixColumnOp_S
<
Agg
,
32
,
32
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
(
agg
,
A_d
,
C_d
,
dimM
,
dimN
);
agg
,
A_d
,
C_d
,
dimM
,
dimN
);
}
}
return
;
return
;
...
@@ -187,11 +157,7 @@ void hl_matrix_column_sum(real *A_d, real *C_d, int dimM, int dimN) {
...
@@ -187,11 +157,7 @@ void hl_matrix_column_sum(real *A_d, real *C_d, int dimM, int dimN) {
CHECK_NOTNULL
(
A_d
);
CHECK_NOTNULL
(
A_d
);
CHECK_NOTNULL
(
C_d
);
CHECK_NOTNULL
(
C_d
);
hl_matrix_column_op
(
aggregate
::
sum
(),
hl_matrix_column_op
(
aggregate
::
sum
(),
A_d
,
C_d
,
dimM
,
dimN
);
A_d
,
C_d
,
dimM
,
dimN
);
CHECK_SYNC
(
"hl_matrix_column_sum failed"
);
CHECK_SYNC
(
"hl_matrix_column_sum failed"
);
}
}
...
@@ -200,11 +166,7 @@ void hl_matrix_column_max(real *A_d, real *C_d, int dimM, int dimN) {
...
@@ -200,11 +166,7 @@ void hl_matrix_column_max(real *A_d, real *C_d, int dimM, int dimN) {
CHECK_NOTNULL
(
A_d
);
CHECK_NOTNULL
(
A_d
);
CHECK_NOTNULL
(
C_d
);
CHECK_NOTNULL
(
C_d
);
hl_matrix_column_op
(
aggregate
::
max
(),
hl_matrix_column_op
(
aggregate
::
max
(),
A_d
,
C_d
,
dimM
,
dimN
);
A_d
,
C_d
,
dimM
,
dimN
);
CHECK_SYNC
(
"hl_matrix_column_max failed"
);
CHECK_SYNC
(
"hl_matrix_column_max failed"
);
}
}
...
@@ -213,11 +175,7 @@ void hl_matrix_column_min(real *A_d, real *C_d, int dimM, int dimN) {
...
@@ -213,11 +175,7 @@ void hl_matrix_column_min(real *A_d, real *C_d, int dimM, int dimN) {
CHECK_NOTNULL
(
A_d
);
CHECK_NOTNULL
(
A_d
);
CHECK_NOTNULL
(
C_d
);
CHECK_NOTNULL
(
C_d
);
hl_matrix_column_op
(
aggregate
::
min
(),
hl_matrix_column_op
(
aggregate
::
min
(),
A_d
,
C_d
,
dimM
,
dimN
);
A_d
,
C_d
,
dimM
,
dimN
);
CHECK_SYNC
(
"hl_matrix_column_min failed"
);
CHECK_SYNC
(
"hl_matrix_column_min failed"
);
}
}
...
@@ -226,16 +184,16 @@ template <int blockSize>
...
@@ -226,16 +184,16 @@ template <int blockSize>
__global__
void
KeVectorSum
(
real
*
E
,
real
*
Sum
,
int
dimM
)
{
__global__
void
KeVectorSum
(
real
*
E
,
real
*
Sum
,
int
dimM
)
{
__shared__
double
sum_s
[
blockSize
];
__shared__
double
sum_s
[
blockSize
];
int
tid
=
threadIdx
.
x
;
int
tid
=
threadIdx
.
x
;
int
index
=
blockIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
int
index
=
blockIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
sum_s
[
tid
]
=
0.0
f
;
sum_s
[
tid
]
=
0.0
f
;
while
(
index
<
dimM
)
{
while
(
index
<
dimM
)
{
sum_s
[
tid
]
+=
E
[
index
];
sum_s
[
tid
]
+=
E
[
index
];
index
+=
blockDim
.
x
*
gridDim
.
y
;
index
+=
blockDim
.
x
*
gridDim
.
y
;
}
}
__syncthreads
();
__syncthreads
();
for
(
int
stride
=
blockSize
/
2
;
stride
>
0
;
stride
=
stride
/
2
)
{
for
(
int
stride
=
blockSize
/
2
;
stride
>
0
;
stride
=
stride
/
2
)
{
if
(
tid
<
stride
)
{
if
(
tid
<
stride
)
{
sum_s
[
tid
]
+=
sum_s
[
tid
+
stride
];
sum_s
[
tid
]
+=
sum_s
[
tid
+
stride
];
}
}
...
@@ -259,38 +217,39 @@ void hl_vector_sum(real *A_d, real *C_h, int dimM) {
...
@@ -259,38 +217,39 @@ void hl_vector_sum(real *A_d, real *C_h, int dimM) {
dim3
threads
(
blockSize
,
1
);
dim3
threads
(
blockSize
,
1
);
dim3
grid
(
blocksX
,
blocksY
);
dim3
grid
(
blocksX
,
blocksY
);
struct
_hl_event_st
hl_event_st
=
{.
cu_event
=
t_resource
.
event
};
struct
_hl_event_st
hl_event_st
=
{.
cu_event
=
t_resource
.
event
};
hl_event_t
hl_event
=
&
hl_event_st
;
hl_event_t
hl_event
=
&
hl_event_st
;
while
(
!
hl_cuda_event_is_ready
(
hl_event
))
{}
while
(
!
hl_cuda_event_is_ready
(
hl_event
))
{
}
KeVectorSum
<
128
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
KeVectorSum
<
128
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
(
A_d
,
t_resource
.
gpu_mem
,
dimM
);
A_d
,
t_resource
.
gpu_mem
,
dimM
);
KeVectorSum
<
128
><<<
1
,
threads
,
0
,
STREAM_DEFAULT
>>>
KeVectorSum
<
128
><<<
1
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
(
t_resource
.
gpu_mem
,
t_resource
.
cpu_mem
,
128
);
t_resource
.
gpu_mem
,
t_resource
.
cpu_mem
,
128
);
hl_memcpy_async
(
C_h
,
t_resource
.
cpu_mem
,
sizeof
(
real
),
HPPL_STREAM_DEFAULT
);
hl_memcpy_async
(
C_h
,
t_resource
.
cpu_mem
,
sizeof
(
real
),
HPPL_STREAM_DEFAULT
);
hl_stream_record_event
(
HPPL_STREAM_DEFAULT
,
hl_event
);
hl_stream_record_event
(
HPPL_STREAM_DEFAULT
,
hl_event
);
hl_stream_synchronize
(
HPPL_STREAM_DEFAULT
);
hl_stream_synchronize
(
HPPL_STREAM_DEFAULT
);
cudaError_t
err
=
(
cudaError_t
)
hl_get_device_last_error
();
cudaError_t
err
=
(
cudaError_t
)
hl_get_device_last_error
();
CHECK_EQ
(
cudaSuccess
,
err
)
CHECK_EQ
(
cudaSuccess
,
err
)
<<
"CUDA error: "
<<
"CUDA error: "
<<
hl_get_device_error_string
((
size_t
)
err
);
<<
hl_get_device_error_string
((
size_t
)
err
);
}
}
template
<
int
blockSize
>
template
<
int
blockSize
>
__global__
void
KeVectorAbsSum
(
real
*
E
,
real
*
Sum
,
int
dimM
)
{
__global__
void
KeVectorAbsSum
(
real
*
E
,
real
*
Sum
,
int
dimM
)
{
__shared__
double
sum_s
[
blockSize
];
__shared__
double
sum_s
[
blockSize
];
int
tid
=
threadIdx
.
x
;
int
tid
=
threadIdx
.
x
;
int
index
=
blockIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
int
index
=
blockIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
sum_s
[
tid
]
=
0.0
f
;
sum_s
[
tid
]
=
0.0
f
;
while
(
index
<
dimM
)
{
while
(
index
<
dimM
)
{
sum_s
[
tid
]
+=
abs
(
E
[
index
]);
sum_s
[
tid
]
+=
abs
(
E
[
index
]);
index
+=
blockDim
.
x
*
gridDim
.
y
;
index
+=
blockDim
.
x
*
gridDim
.
y
;
}
}
__syncthreads
();
__syncthreads
();
for
(
int
stride
=
blockSize
/
2
;
stride
>
0
;
stride
=
stride
/
2
)
{
for
(
int
stride
=
blockSize
/
2
;
stride
>
0
;
stride
=
stride
/
2
)
{
if
(
tid
<
stride
)
{
if
(
tid
<
stride
)
{
sum_s
[
tid
]
+=
sum_s
[
tid
+
stride
];
sum_s
[
tid
]
+=
sum_s
[
tid
+
stride
];
}
}
...
@@ -314,20 +273,21 @@ void hl_vector_abs_sum(real *A_d, real *C_h, int dimM) {
...
@@ -314,20 +273,21 @@ void hl_vector_abs_sum(real *A_d, real *C_h, int dimM) {
dim3
threads
(
blockSize
,
1
);
dim3
threads
(
blockSize
,
1
);
dim3
grid
(
blocksX
,
blocksY
);
dim3
grid
(
blocksX
,
blocksY
);
struct
_hl_event_st
hl_event_st
=
{.
cu_event
=
t_resource
.
event
};
struct
_hl_event_st
hl_event_st
=
{.
cu_event
=
t_resource
.
event
};
hl_event_t
hl_event
=
&
hl_event_st
;
hl_event_t
hl_event
=
&
hl_event_st
;
while
(
!
hl_cuda_event_is_ready
(
hl_event
))
{}
while
(
!
hl_cuda_event_is_ready
(
hl_event
))
{
}
KeVectorAbsSum
<
128
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
KeVectorAbsSum
<
128
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
(
A_d
,
t_resource
.
gpu_mem
,
dimM
);
A_d
,
t_resource
.
gpu_mem
,
dimM
);
KeVectorAbsSum
<
128
><<<
1
,
threads
,
0
,
STREAM_DEFAULT
>>>
KeVectorAbsSum
<
128
><<<
1
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
(
t_resource
.
gpu_mem
,
t_resource
.
cpu_mem
,
128
);
t_resource
.
gpu_mem
,
t_resource
.
cpu_mem
,
128
);
hl_memcpy_async
(
C_h
,
t_resource
.
cpu_mem
,
sizeof
(
real
),
HPPL_STREAM_DEFAULT
);
hl_memcpy_async
(
C_h
,
t_resource
.
cpu_mem
,
sizeof
(
real
),
HPPL_STREAM_DEFAULT
);
hl_stream_record_event
(
HPPL_STREAM_DEFAULT
,
hl_event
);
hl_stream_record_event
(
HPPL_STREAM_DEFAULT
,
hl_event
);
hl_stream_synchronize
(
HPPL_STREAM_DEFAULT
);
hl_stream_synchronize
(
HPPL_STREAM_DEFAULT
);
cudaError_t
err
=
(
cudaError_t
)
hl_get_device_last_error
();
cudaError_t
err
=
(
cudaError_t
)
hl_get_device_last_error
();
CHECK_EQ
(
cudaSuccess
,
err
)
CHECK_EQ
(
cudaSuccess
,
err
)
<<
"CUDA error: "
<<
"CUDA error: "
<<
hl_get_device_error_string
((
size_t
)
err
);
<<
hl_get_device_error_string
((
size_t
)
err
);
}
}
paddle/cuda/src/hl_cuda_cnn.cu
浏览文件 @
1d4fa243
此差异已折叠。
点击以展开。
paddle/cuda/src/hl_cuda_lstm.cu
浏览文件 @
1d4fa243
此差异已折叠。
点击以展开。
paddle/cuda/src/hl_cuda_matrix.cu
浏览文件 @
1d4fa243
此差异已折叠。
点击以展开。
paddle/cuda/src/hl_cuda_sequence.cu
浏览文件 @
1d4fa243
...
@@ -16,36 +16,36 @@ limitations under the License. */
...
@@ -16,36 +16,36 @@ limitations under the License. */
#include "hl_device_functions.cuh"
#include "hl_device_functions.cuh"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Logging.h"
__global__
void
KeMaxSequenceForward
(
real
*
input
,
__global__
void
KeMaxSequenceForward
(
real
*
input
,
const
int
*
sequence
,
const
int
*
sequence
,
real
*
output
,
real
*
output
,
int
*
index
,
int
*
index
,
int
numSequences
,
int
numSequences
,
int
dim
)
{
int
dim
)
{
int
dimIdx
=
threadIdx
.
x
;
int
dimIdx
=
threadIdx
.
x
;
int
sequenceId
=
blockIdx
.
x
;
int
sequenceId
=
blockIdx
.
x
;
if
(
sequenceId
>=
numSequences
)
return
;
if
(
sequenceId
>=
numSequences
)
return
;
int
start
=
sequence
[
sequenceId
];
int
start
=
sequence
[
sequenceId
];
int
end
=
sequence
[
sequenceId
+
1
];
int
end
=
sequence
[
sequenceId
+
1
];
for
(
int
i
=
dimIdx
;
i
<
dim
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
dimIdx
;
i
<
dim
;
i
+=
blockDim
.
x
)
{
real
tmp
=
-
HL_FLOAT_MAX
;
real
tmp
=
-
HL_FLOAT_MAX
;
int
tmpId
=
-
1
;
int
tmpId
=
-
1
;
for
(
int
insId
=
start
;
insId
<
end
;
insId
++
)
{
for
(
int
insId
=
start
;
insId
<
end
;
insId
++
)
{
if
(
tmp
<
input
[
insId
*
dim
+
i
])
{
if
(
tmp
<
input
[
insId
*
dim
+
i
])
{
tmp
=
input
[
insId
*
dim
+
i
];
tmp
=
input
[
insId
*
dim
+
i
];
tmpId
=
insId
;
tmpId
=
insId
;
}
}
}
}
output
[
sequenceId
*
dim
+
i
]
=
tmp
;
output
[
sequenceId
*
dim
+
i
]
=
tmp
;
index
[
sequenceId
*
dim
+
i
]
=
tmpId
;
index
[
sequenceId
*
dim
+
i
]
=
tmpId
;
}
}
}
}
void
hl_max_sequence_forward
(
real
*
input
,
void
hl_max_sequence_forward
(
real
*
input
,
const
int
*
sequence
,
const
int
*
sequence
,
real
*
output
,
real
*
output
,
int
*
index
,
int
*
index
,
int
numSequences
,
int
numSequences
,
int
dim
)
{
int
dim
)
{
CHECK_NOTNULL
(
input
);
CHECK_NOTNULL
(
input
);
...
@@ -55,29 +55,23 @@ void hl_max_sequence_forward(real* input,
...
@@ -55,29 +55,23 @@ void hl_max_sequence_forward(real* input,
dim3
threads
(
256
,
1
);
dim3
threads
(
256
,
1
);
dim3
grid
(
numSequences
,
1
);
dim3
grid
(
numSequences
,
1
);
KeMaxSequenceForward
<<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
KeMaxSequenceForward
<<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
(
input
,
sequence
,
output
,
index
,
numSequences
,
dim
);
input
,
sequence
,
output
,
index
,
numSequences
,
dim
);
CHECK_SYNC
(
"hl_max_sequence_forward failed"
);
CHECK_SYNC
(
"hl_max_sequence_forward failed"
);
}
}
__global__
void
KeMaxSequenceBackward
(
real
*
outputGrad
,
__global__
void
KeMaxSequenceBackward
(
int
*
index
,
real
*
outputGrad
,
int
*
index
,
real
*
inputGrad
,
int
numSequences
,
int
dim
)
{
real
*
inputGrad
,
int
numSequences
,
int
dim
)
{
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
colIdx
=
idx
%
dim
;
int
colIdx
=
idx
%
dim
;
if
(
idx
<
numSequences
*
dim
)
{
if
(
idx
<
numSequences
*
dim
)
{
int
insId
=
index
[
idx
];
int
insId
=
index
[
idx
];
inputGrad
[
insId
*
dim
+
colIdx
]
+=
outputGrad
[
idx
];
inputGrad
[
insId
*
dim
+
colIdx
]
+=
outputGrad
[
idx
];
}
}
}
}
void
hl_max_sequence_backward
(
real
*
outputGrad
,
void
hl_max_sequence_backward
(
int
*
index
,
real
*
outputGrad
,
int
*
index
,
real
*
inputGrad
,
int
numSequences
,
int
dim
)
{
real
*
inputGrad
,
int
numSequences
,
int
dim
)
{
CHECK_NOTNULL
(
outputGrad
);
CHECK_NOTNULL
(
outputGrad
);
CHECK_NOTNULL
(
index
);
CHECK_NOTNULL
(
index
);
CHECK_NOTNULL
(
inputGrad
);
CHECK_NOTNULL
(
inputGrad
);
...
@@ -85,12 +79,12 @@ void hl_max_sequence_backward(real* outputGrad,
...
@@ -85,12 +79,12 @@ void hl_max_sequence_backward(real* outputGrad,
unsigned
int
blocks
=
(
numSequences
*
dim
+
128
-
1
)
/
128
;
unsigned
int
blocks
=
(
numSequences
*
dim
+
128
-
1
)
/
128
;
dim3
threads
(
128
,
1
);
dim3
threads
(
128
,
1
);
dim3
grid
(
blocks
,
1
);
dim3
grid
(
blocks
,
1
);
KeMaxSequenceBackward
<<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
KeMaxSequenceBackward
<<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
(
outputGrad
,
index
,
inputGrad
,
numSequences
,
dim
);
outputGrad
,
index
,
inputGrad
,
numSequences
,
dim
);
CHECK_SYNC
(
"hl_max_sequence_backward failed"
);
CHECK_SYNC
(
"hl_max_sequence_backward failed"
);
}
}
template
<
int
blockDimX
,
int
blockDimY
,
int
gridDimX
,
bool
AddRow
>
template
<
int
blockDimX
,
int
blockDimY
,
int
gridDimX
,
bool
AddRow
>
__global__
void
KeMatrixAddRows
(
real
*
output
,
__global__
void
KeMatrixAddRows
(
real
*
output
,
real
*
table
,
real
*
table
,
int
*
ids
,
int
*
ids
,
...
@@ -104,8 +98,8 @@ __global__ void KeMatrixAddRows(real* output,
...
@@ -104,8 +98,8 @@ __global__ void KeMatrixAddRows(real* output,
while
(
sampleId
<
numSamples
)
{
while
(
sampleId
<
numSamples
)
{
int
tableId
=
ids
[
sampleId
];
int
tableId
=
ids
[
sampleId
];
if
((
0
<=
tableId
)
&&
(
tableId
<
tableSize
))
{
if
((
0
<=
tableId
)
&&
(
tableId
<
tableSize
))
{
real
*
outputData
=
output
+
sampleId
*
dim
;
real
*
outputData
=
output
+
sampleId
*
dim
;
real
*
tableData
=
table
+
tableId
*
dim
;
real
*
tableData
=
table
+
tableId
*
dim
;
for
(
int
i
=
idx
;
i
<
dim
;
i
+=
blockDimX
)
{
for
(
int
i
=
idx
;
i
<
dim
;
i
+=
blockDimX
)
{
if
(
AddRow
==
0
)
{
if
(
AddRow
==
0
)
{
outputData
[
i
]
+=
tableData
[
i
];
outputData
[
i
]
+=
tableData
[
i
];
...
@@ -114,24 +108,27 @@ __global__ void KeMatrixAddRows(real* output,
...
@@ -114,24 +108,27 @@ __global__ void KeMatrixAddRows(real* output,
}
}
}
}
}
}
sampleId
+=
blockDimY
*
gridDimX
;
sampleId
+=
blockDimY
*
gridDimX
;
}
}
}
}
template
<
int
blockDimX
,
int
blockDimY
,
int
gridDimX
,
bool
seq2batch
,
bool
isAdd
>
template
<
int
blockDimX
,
__global__
int
blockDimY
,
void
KeSequence2Batch
(
real
*
batch
,
int
gridDimX
,
real
*
sequence
,
bool
seq2batch
,
const
int
*
batchIndex
,
bool
isAdd
>
int
seqWidth
,
__global__
void
KeSequence2Batch
(
real
*
batch
,
int
batchCount
)
{
real
*
sequence
,
const
int
*
batchIndex
,
int
seqWidth
,
int
batchCount
)
{
int
idx
=
threadIdx
.
x
;
int
idx
=
threadIdx
.
x
;
int
idy
=
threadIdx
.
y
;
int
idy
=
threadIdx
.
y
;
int
id
=
blockIdx
.
x
+
idy
*
gridDimX
;
int
id
=
blockIdx
.
x
+
idy
*
gridDimX
;
while
(
id
<
batchCount
)
{
while
(
id
<
batchCount
)
{
int
seqId
=
batchIndex
[
id
];
int
seqId
=
batchIndex
[
id
];
real
*
batchData
=
batch
+
id
*
seqWidth
;
real
*
batchData
=
batch
+
id
*
seqWidth
;
real
*
seqData
=
sequence
+
seqId
*
seqWidth
;
real
*
seqData
=
sequence
+
seqId
*
seqWidth
;
for
(
int
i
=
idx
;
i
<
seqWidth
;
i
+=
blockDimX
)
{
for
(
int
i
=
idx
;
i
<
seqWidth
;
i
+=
blockDimX
)
{
if
(
seq2batch
)
{
if
(
seq2batch
)
{
if
(
isAdd
)
{
if
(
isAdd
)
{
...
@@ -147,13 +144,13 @@ void KeSequence2Batch(real *batch,
...
@@ -147,13 +144,13 @@ void KeSequence2Batch(real *batch,
}
}
}
}
}
}
id
+=
blockDimY
*
gridDimX
;
id
+=
blockDimY
*
gridDimX
;
}
}
}
}
void
hl_sequence2batch_copy
(
real
*
batch
,
void
hl_sequence2batch_copy
(
real
*
batch
,
real
*
sequence
,
real
*
sequence
,
const
int
*
batchIndex
,
const
int
*
batchIndex
,
int
seqWidth
,
int
seqWidth
,
int
batchCount
,
int
batchCount
,
bool
seq2batch
)
{
bool
seq2batch
)
{
...
@@ -164,18 +161,18 @@ void hl_sequence2batch_copy(real *batch,
...
@@ -164,18 +161,18 @@ void hl_sequence2batch_copy(real *batch,
dim3
threads
(
128
,
8
);
dim3
threads
(
128
,
8
);
dim3
grid
(
8
,
1
);
dim3
grid
(
8
,
1
);
if
(
seq2batch
)
{
if
(
seq2batch
)
{
KeSequence2Batch
<
128
,
8
,
8
,
1
,
0
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
KeSequence2Batch
<
128
,
8
,
8
,
1
,
0
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
(
batch
,
sequence
,
batchIndex
,
seqWidth
,
batchCount
);
batch
,
sequence
,
batchIndex
,
seqWidth
,
batchCount
);
}
else
{
}
else
{
KeSequence2Batch
<
128
,
8
,
8
,
0
,
0
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
KeSequence2Batch
<
128
,
8
,
8
,
0
,
0
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
(
batch
,
sequence
,
batchIndex
,
seqWidth
,
batchCount
);
batch
,
sequence
,
batchIndex
,
seqWidth
,
batchCount
);
}
}
CHECK_SYNC
(
"hl_sequence2batch_copy failed"
);
CHECK_SYNC
(
"hl_sequence2batch_copy failed"
);
}
}
void
hl_sequence2batch_add
(
real
*
batch
,
void
hl_sequence2batch_add
(
real
*
batch
,
real
*
sequence
,
real
*
sequence
,
int
*
batchIndex
,
int
*
batchIndex
,
int
seqWidth
,
int
seqWidth
,
int
batchCount
,
int
batchCount
,
bool
seq2batch
)
{
bool
seq2batch
)
{
...
@@ -186,23 +183,22 @@ void hl_sequence2batch_add(real *batch,
...
@@ -186,23 +183,22 @@ void hl_sequence2batch_add(real *batch,
dim3
threads
(
128
,
8
);
dim3
threads
(
128
,
8
);
dim3
grid
(
8
,
1
);
dim3
grid
(
8
,
1
);
if
(
seq2batch
)
{
if
(
seq2batch
)
{
KeSequence2Batch
<
128
,
8
,
8
,
1
,
1
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
KeSequence2Batch
<
128
,
8
,
8
,
1
,
1
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
(
batch
,
sequence
,
batchIndex
,
seqWidth
,
batchCount
);
batch
,
sequence
,
batchIndex
,
seqWidth
,
batchCount
);
}
else
{
}
else
{
KeSequence2Batch
<
128
,
8
,
8
,
0
,
1
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
KeSequence2Batch
<
128
,
8
,
8
,
0
,
1
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
(
batch
,
sequence
,
batchIndex
,
seqWidth
,
batchCount
);
batch
,
sequence
,
batchIndex
,
seqWidth
,
batchCount
);
}
}
CHECK_SYNC
(
"hl_sequence2batch_add failed"
);
CHECK_SYNC
(
"hl_sequence2batch_add failed"
);
}
}
template
<
bool
normByTimes
,
bool
seq2batch
>
template
<
bool
normByTimes
,
bool
seq2batch
>
__global__
__global__
void
KeSequence2BatchPadding
(
real
*
batch
,
void
KeSequence2BatchPadding
(
real
*
batch
,
real
*
sequence
,
real
*
sequence
,
const
int
*
sequenceStartPositions
,
const
int
*
sequenceStartPositions
,
const
size_t
sequenceWidth
,
const
size_t
sequenceWidth
,
const
size_t
maxSequenceLength
,
const
size_t
maxSequenceLength
,
const
size_t
numSequences
)
{
const
size_t
numSequences
)
{
int
batchIdx
=
blockIdx
.
y
;
int
batchIdx
=
blockIdx
.
y
;
int
sequenceStart
=
sequenceStartPositions
[
batchIdx
];
int
sequenceStart
=
sequenceStartPositions
[
batchIdx
];
int
sequenceLength
=
sequenceStartPositions
[
batchIdx
+
1
]
-
sequenceStart
;
int
sequenceLength
=
sequenceStartPositions
[
batchIdx
+
1
]
-
sequenceStart
;
...
@@ -276,37 +272,49 @@ void hl_sequence2batch_copy_padding(real* batch,
...
@@ -276,37 +272,49 @@ void hl_sequence2batch_copy_padding(real* batch,
if
(
seq2batch
)
{
if
(
seq2batch
)
{
/* sequence -> batch */
/* sequence -> batch */
if
(
normByTimes
)
{
if
(
normByTimes
)
{
KeSequence2BatchPadding
<
1
,
1
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
KeSequence2BatchPadding
<
1
,
1
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
batch
,
sequence
,
sequenceStartPositions
,
batch
,
sequenceWidth
,
maxSequenceLength
,
numSequences
);
sequence
,
sequenceStartPositions
,
sequenceWidth
,
maxSequenceLength
,
numSequences
);
}
else
{
}
else
{
KeSequence2BatchPadding
<
0
,
1
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
KeSequence2BatchPadding
<
0
,
1
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
batch
,
sequence
,
sequenceStartPositions
,
batch
,
sequenceWidth
,
maxSequenceLength
,
numSequences
);
sequence
,
sequenceStartPositions
,
sequenceWidth
,
maxSequenceLength
,
numSequences
);
}
}
}
else
{
}
else
{
/* batch -> sequence */
/* batch -> sequence */
if
(
normByTimes
)
{
if
(
normByTimes
)
{
KeSequence2BatchPadding
<
1
,
0
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
KeSequence2BatchPadding
<
1
,
0
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
batch
,
sequence
,
sequenceStartPositions
,
batch
,
sequenceWidth
,
maxSequenceLength
,
numSequences
);
sequence
,
sequenceStartPositions
,
sequenceWidth
,
maxSequenceLength
,
numSequences
);
}
else
{
}
else
{
KeSequence2BatchPadding
<
0
,
0
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
KeSequence2BatchPadding
<
0
,
0
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
batch
,
sequence
,
sequenceStartPositions
,
batch
,
sequenceWidth
,
maxSequenceLength
,
numSequences
);
sequence
,
sequenceStartPositions
,
sequenceWidth
,
maxSequenceLength
,
numSequences
);
}
}
}
}
CHECK_SYNC
(
"hl_sequence2batch_copy_padding failed"
);
CHECK_SYNC
(
"hl_sequence2batch_copy_padding failed"
);
}
}
__device__
inline
float
my_rsqrt
(
float
x
)
{
__device__
inline
float
my_rsqrt
(
float
x
)
{
return
rsqrtf
(
x
);
}
return
rsqrtf
(
x
);
}
__device__
inline
double
my_rsqrt
(
double
x
)
{
__device__
inline
double
my_rsqrt
(
double
x
)
{
return
rsqrt
(
x
);
}
return
rsqrt
(
x
);
}
__global__
void
KeSequenceAvgForward
(
real
*
dst
,
__global__
void
KeSequenceAvgForward
(
real
*
dst
,
real
*
src
,
real
*
src
,
...
@@ -327,8 +335,8 @@ __global__ void KeSequenceAvgForward(real* dst,
...
@@ -327,8 +335,8 @@ __global__ void KeSequenceAvgForward(real* dst,
for
(
int
i
=
start
;
i
<
end
;
i
++
)
{
for
(
int
i
=
start
;
i
<
end
;
i
++
)
{
sum
+=
src
[
i
*
width
+
col
];
sum
+=
src
[
i
*
width
+
col
];
}
}
sum
=
mode
==
1
?
sum
:
sum
=
mode
==
1
?
sum
:
(
mode
==
0
?
sum
/
seqLength
(
mode
==
0
?
sum
/
seqLength
:
sum
*
my_rsqrt
((
real
)
seqLength
));
:
sum
*
my_rsqrt
((
real
)
seqLength
));
dst
[
gid
]
+=
sum
;
dst
[
gid
]
+=
sum
;
}
}
}
}
...
@@ -347,10 +355,10 @@ void hl_sequence_avg_forward(real* dst,
...
@@ -347,10 +355,10 @@ void hl_sequence_avg_forward(real* dst,
int
grid
=
DIVUP
(
width
*
height
,
512
);
int
grid
=
DIVUP
(
width
*
height
,
512
);
CHECK
(
mode
==
0
||
mode
==
1
||
mode
==
2
)
CHECK
(
mode
==
0
||
mode
==
1
||
mode
==
2
)
<<
"mode error in hl_sequence_avg_forward!"
;
<<
"mode error in hl_sequence_avg_forward!"
;
KeSequenceAvgForward
<<<
grid
,
block
,
0
,
STREAM_DEFAULT
>>>
KeSequenceAvgForward
<<<
grid
,
block
,
0
,
STREAM_DEFAULT
>>>
(
(
dst
,
src
,
starts
,
height
,
width
,
mode
);
dst
,
src
,
starts
,
height
,
width
,
mode
);
CHECK_SYNC
(
"hl_sequence_avg_forward failed"
);
CHECK_SYNC
(
"hl_sequence_avg_forward failed"
);
}
}
...
@@ -370,8 +378,8 @@ __global__ void KeSequenceAvgBackward(real* dst,
...
@@ -370,8 +378,8 @@ __global__ void KeSequenceAvgBackward(real* dst,
int
seqLength
=
end
-
start
;
int
seqLength
=
end
-
start
;
if
(
seqLength
==
0
)
return
;
if
(
seqLength
==
0
)
return
;
real
grad
=
src
[
gid
];
real
grad
=
src
[
gid
];
grad
=
mode
==
1
?
grad
:
grad
=
mode
==
1
?
grad
:
(
mode
==
0
?
grad
/
seqLength
(
mode
==
0
?
grad
/
seqLength
:
grad
*
my_rsqrt
((
real
)
seqLength
));
:
grad
*
my_rsqrt
((
real
)
seqLength
));
for
(
int
i
=
start
;
i
<
end
;
i
++
)
{
for
(
int
i
=
start
;
i
<
end
;
i
++
)
{
dst
[
i
*
width
+
col
]
+=
grad
;
dst
[
i
*
width
+
col
]
+=
grad
;
}
}
...
@@ -392,9 +400,9 @@ void hl_sequence_avg_backward(real* dst,
...
@@ -392,9 +400,9 @@ void hl_sequence_avg_backward(real* dst,
int
grid
=
DIVUP
(
width
*
height
,
512
);
int
grid
=
DIVUP
(
width
*
height
,
512
);
CHECK
(
mode
==
0
||
mode
==
1
||
mode
==
2
)
CHECK
(
mode
==
0
||
mode
==
1
||
mode
==
2
)
<<
"mode error in hl_sequence_avg_backward!"
;
<<
"mode error in hl_sequence_avg_backward!"
;
KeSequenceAvgBackward
<<<
grid
,
block
,
0
,
STREAM_DEFAULT
>>>
KeSequenceAvgBackward
<<<
grid
,
block
,
0
,
STREAM_DEFAULT
>>>
(
(
dst
,
src
,
starts
,
height
,
width
,
mode
);
dst
,
src
,
starts
,
height
,
width
,
mode
);
CHECK_SYNC
(
"hl_sequence_avg_backward failed"
);
CHECK_SYNC
(
"hl_sequence_avg_backward failed"
);
}
}
paddle/cuda/src/hl_cuda_sparse.cu
浏览文件 @
1d4fa243
此差异已折叠。
点击以展开。
paddle/cuda/src/hl_perturbation_util.cu
浏览文件 @
1d4fa243
...
@@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include <cmath>
#include <stdlib.h>
#include <stdlib.h>
#include "hl_cuda.h"
#include <cmath>
#include "hl_time.h"
#include "hl_base.h"
#include "hl_base.h"
#include "hl_cuda.h"
#include "hl_perturbation_util.cuh"
#include "hl_perturbation_util.cuh"
#include "hl_time.h"
#define _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
...
@@ -30,10 +29,16 @@ limitations under the License. */
...
@@ -30,10 +29,16 @@ limitations under the License. */
* centerX, centerY: translation.
* centerX, centerY: translation.
* sourceX, sourceY: output coordinates in the original image.
* sourceX, sourceY: output coordinates in the original image.
*/
*/
__device__
void
getTranformCoord
(
int
x
,
int
y
,
real
theta
,
real
scale
,
__device__
void
getTranformCoord
(
int
x
,
real
tgtCenter
,
real
imgCenter
,
int
y
,
real
centerR
,
real
centerC
,
real
theta
,
int
*
sourceX
,
int
*
sourceY
)
{
real
scale
,
real
tgtCenter
,
real
imgCenter
,
real
centerR
,
real
centerC
,
int
*
sourceX
,
int
*
sourceY
)
{
real
H
[
4
]
=
{
cosf
(
-
theta
),
-
sinf
(
-
theta
),
sinf
(
-
theta
),
cosf
(
-
theta
)};
real
H
[
4
]
=
{
cosf
(
-
theta
),
-
sinf
(
-
theta
),
sinf
(
-
theta
),
cosf
(
-
theta
)};
// compute coornidates in the rotated and scaled image
// compute coornidates in the rotated and scaled image
...
@@ -57,11 +62,17 @@ __device__ void getTranformCoord(int x, int y, real theta, real scale,
...
@@ -57,11 +62,17 @@ __device__ void getTranformCoord(int x, int y, real theta, real scale,
* created by Wei Xu (genome), converted by Jiang Wang
* created by Wei Xu (genome), converted by Jiang Wang
*/
*/
__global__
void
kSamplingPatches
(
const
real
*
imgs
,
real
*
targets
,
__global__
void
kSamplingPatches
(
const
real
*
imgs
,
int
imgSize
,
int
tgtSize
,
const
int
channels
,
real
*
targets
,
int
samplingRate
,
const
real
*
thetas
,
int
imgSize
,
const
real
*
scales
,
const
int
*
centerRs
,
int
tgtSize
,
const
int
*
centerCs
,
const
real
padValue
,
const
int
channels
,
int
samplingRate
,
const
real
*
thetas
,
const
real
*
scales
,
const
int
*
centerRs
,
const
int
*
centerCs
,
const
real
padValue
,
const
int
numImages
)
{
const
int
numImages
)
{
const
int
caseIdx
=
blockIdx
.
x
*
4
+
threadIdx
.
x
;
const
int
caseIdx
=
blockIdx
.
x
*
4
+
threadIdx
.
x
;
const
int
pxIdx
=
blockIdx
.
y
*
128
+
threadIdx
.
y
;
const
int
pxIdx
=
blockIdx
.
y
*
128
+
threadIdx
.
y
;
...
@@ -80,8 +91,15 @@ __global__ void kSamplingPatches(const real* imgs, real* targets,
...
@@ -80,8 +91,15 @@ __global__ void kSamplingPatches(const real* imgs, real* targets,
const
int
pxY
=
pxIdx
/
tgtSize
;
const
int
pxY
=
pxIdx
/
tgtSize
;
int
srcPxX
,
srcPxY
;
int
srcPxX
,
srcPxY
;
getTranformCoord
(
pxX
,
pxY
,
thetas
[
imgIdx
],
scales
[
imgIdx
],
tgtCenter
,
getTranformCoord
(
pxX
,
imgCenter
,
centerCs
[
caseIdx
],
centerRs
[
caseIdx
],
&
srcPxX
,
pxY
,
thetas
[
imgIdx
],
scales
[
imgIdx
],
tgtCenter
,
imgCenter
,
centerCs
[
caseIdx
],
centerRs
[
caseIdx
],
&
srcPxX
,
&
srcPxY
);
&
srcPxY
);
imgs
+=
(
imgIdx
*
imgPixels
+
srcPxY
*
imgSize
+
srcPxX
)
*
channels
;
imgs
+=
(
imgIdx
*
imgPixels
+
srcPxY
*
imgSize
+
srcPxX
)
*
channels
;
...
@@ -100,10 +118,15 @@ __global__ void kSamplingPatches(const real* imgs, real* targets,
...
@@ -100,10 +118,15 @@ __global__ void kSamplingPatches(const real* imgs, real* targets,
*
*
* created by Wei Xu
* created by Wei Xu
*/
*/
void
hl_generate_disturb_params
(
real
*&
gpuAngle
,
real
*&
gpuScaleRatio
,
void
hl_generate_disturb_params
(
real
*&
gpuAngle
,
int
*&
gpuCenterR
,
int
*&
gpuCenterC
,
real
*&
gpuScaleRatio
,
int
numImages
,
int
imgSize
,
real
rotateAngle
,
int
*&
gpuCenterR
,
real
scaleRatio
,
int
samplingRate
,
int
*&
gpuCenterC
,
int
numImages
,
int
imgSize
,
real
rotateAngle
,
real
scaleRatio
,
int
samplingRate
,
bool
isTrain
)
{
bool
isTrain
)
{
// The number of output samples.
// The number of output samples.
int
numPatches
=
numImages
*
samplingRate
;
int
numPatches
=
numImages
*
samplingRate
;
...
@@ -123,7 +146,8 @@ void hl_generate_disturb_params(real*& gpuAngle, real*& gpuScaleRatio,
...
@@ -123,7 +146,8 @@ void hl_generate_disturb_params(real*& gpuAngle, real*& gpuScaleRatio,
for
(
int
i
=
0
;
i
<
numImages
;
i
++
)
{
for
(
int
i
=
0
;
i
<
numImages
;
i
++
)
{
r_angle
[
i
]
=
r_angle
[
i
]
=
(
rotateAngle
*
M_PI
/
180.0
)
*
(
rand
()
/
(
RAND_MAX
+
1.0
)
// NOLINT
(
rotateAngle
*
M_PI
/
180.0
)
*
(
rand
()
/
(
RAND_MAX
+
1.0
)
// NOLINT
-
0.5
);
-
0.5
);
s_ratio
[
i
]
=
s_ratio
[
i
]
=
1
+
(
rand
()
/
(
RAND_MAX
+
1.0
)
-
0.5
)
*
scaleRatio
;
// NOLINT
1
+
(
rand
()
/
(
RAND_MAX
+
1.0
)
-
0.5
)
*
scaleRatio
;
// NOLINT
}
}
...
@@ -140,8 +164,10 @@ void hl_generate_disturb_params(real*& gpuAngle, real*& gpuScaleRatio,
...
@@ -140,8 +164,10 @@ void hl_generate_disturb_params(real*& gpuAngle, real*& gpuScaleRatio,
int
pxY
=
int
pxY
=
(
int
)(
real
(
imgSize
-
1
)
*
rand
()
/
(
RAND_MAX
+
1.0
));
// NOLINT
(
int
)(
real
(
imgSize
-
1
)
*
rand
()
/
(
RAND_MAX
+
1.0
));
// NOLINT
const
real
H
[
4
]
=
{
cos
(
-
r_angle
[
i
]),
-
sin
(
-
r_angle
[
i
]),
const
real
H
[
4
]
=
{
cos
(
-
r_angle
[
i
]),
sin
(
-
r_angle
[
i
]),
cos
(
-
r_angle
[
i
])};
-
sin
(
-
r_angle
[
i
]),
sin
(
-
r_angle
[
i
]),
cos
(
-
r_angle
[
i
])};
real
x
=
pxX
-
imgCenter
;
real
x
=
pxX
-
imgCenter
;
real
y
=
pxY
-
imgCenter
;
real
y
=
pxY
-
imgCenter
;
real
xx
=
H
[
0
]
*
x
+
H
[
1
]
*
y
;
real
xx
=
H
[
0
]
*
x
+
H
[
1
]
*
y
;
...
@@ -185,9 +211,12 @@ void hl_generate_disturb_params(real*& gpuAngle, real*& gpuScaleRatio,
...
@@ -185,9 +211,12 @@ void hl_generate_disturb_params(real*& gpuAngle, real*& gpuScaleRatio,
delete
[]
center_c
;
delete
[]
center_c
;
}
}
void
hl_conv_random_disturb_with_params
(
const
real
*
images
,
int
imgSize
,
void
hl_conv_random_disturb_with_params
(
const
real
*
images
,
int
tgtSize
,
int
channels
,
int
imgSize
,
int
numImages
,
int
samplingRate
,
int
tgtSize
,
int
channels
,
int
numImages
,
int
samplingRate
,
const
real
*
gpuRotationAngle
,
const
real
*
gpuRotationAngle
,
const
real
*
gpuScaleRatio
,
const
real
*
gpuScaleRatio
,
const
int
*
gpuCenterR
,
const
int
*
gpuCenterR
,
...
@@ -202,29 +231,59 @@ void hl_conv_random_disturb_with_params(const real* images, int imgSize,
...
@@ -202,29 +231,59 @@ void hl_conv_random_disturb_with_params(const real* images, int imgSize,
dim3
threadsPerBlock
(
4
,
128
);
dim3
threadsPerBlock
(
4
,
128
);
dim3
numBlocks
(
DIVUP
(
numPatches
,
4
),
DIVUP
(
targetSize
,
128
));
dim3
numBlocks
(
DIVUP
(
numPatches
,
4
),
DIVUP
(
targetSize
,
128
));
kSamplingPatches
<<<
numBlocks
,
threadsPerBlock
>>>
kSamplingPatches
<<<
numBlocks
,
threadsPerBlock
>>>
(
images
,
(
images
,
target
,
imgSize
,
tgtSize
,
channels
,
samplingRate
,
target
,
gpuRotationAngle
,
gpuScaleRatio
,
gpuCenterR
,
gpuCenterC
,
imgSize
,
paddingValue
,
numImages
);
tgtSize
,
channels
,
samplingRate
,
gpuRotationAngle
,
gpuScaleRatio
,
gpuCenterR
,
gpuCenterC
,
paddingValue
,
numImages
);
hl_device_synchronize
();
hl_device_synchronize
();
}
}
void
hl_conv_random_disturb
(
const
real
*
images
,
int
imgSize
,
void
hl_conv_random_disturb
(
const
real
*
images
,
int
tgtSize
,
int
channels
,
int
numImages
,
int
imgSize
,
real
scaleRatio
,
real
rotateAngle
,
int
tgtSize
,
int
samplingRate
,
real
*
gpu_r_angle
,
int
channels
,
real
*
gpu_s_ratio
,
int
*
gpu_center_r
,
int
numImages
,
int
*
gpu_center_c
,
int
paddingValue
,
real
scaleRatio
,
bool
isTrain
,
real
*
targets
)
{
real
rotateAngle
,
int
samplingRate
,
real
*
gpu_r_angle
,
real
*
gpu_s_ratio
,
int
*
gpu_center_r
,
int
*
gpu_center_c
,
int
paddingValue
,
bool
isTrain
,
real
*
targets
)
{
// generate the random disturbance sequence and the sampling locations
// generate the random disturbance sequence and the sampling locations
hl_generate_disturb_params
(
gpu_r_angle
,
gpu_s_ratio
,
gpu_center_r
,
hl_generate_disturb_params
(
gpu_r_angle
,
gpu_center_c
,
numImages
,
imgSize
,
rotateAngle
,
gpu_s_ratio
,
scaleRatio
,
samplingRate
,
isTrain
);
gpu_center_r
,
gpu_center_c
,
hl_conv_random_disturb_with_params
(
numImages
,
images
,
imgSize
,
tgtSize
,
channels
,
numImages
,
imgSize
,
samplingRate
,
gpu_r_angle
,
gpu_s_ratio
,
rotateAngle
,
gpu_center_r
,
gpu_center_r
,
paddingValue
,
scaleRatio
,
targets
);
samplingRate
,
isTrain
);
hl_conv_random_disturb_with_params
(
images
,
imgSize
,
tgtSize
,
channels
,
numImages
,
samplingRate
,
gpu_r_angle
,
gpu_s_ratio
,
gpu_center_r
,
gpu_center_r
,
paddingValue
,
targets
);
}
}
paddle/cuda/src/hl_table_apply.cu
浏览文件 @
1d4fa243
...
@@ -12,15 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,15 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "hl_base.h"
#include "hl_base.h"
#include "hl_device_functions.cuh"
#include "hl_cuda.h"
#include "hl_cuda.h"
#include "hl_device_functions.cuh"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Logging.h"
template
<
int
blockDimX
,
int
blockDimY
,
int
gridDimX
,
bool
AddRow
>
template
<
int
blockDimX
,
int
blockDimY
,
int
gridDimX
,
bool
AddRow
>
__global__
void
KeMatrixAddRows
(
real
*
output
,
int
ldo
,
__global__
void
KeMatrixAddRows
(
real
*
output
,
real
*
table
,
int
ldt
,
int
ldo
,
real
*
table
,
int
ldt
,
int
*
ids
,
int
*
ids
,
int
numSamples
,
int
numSamples
,
int
tableSize
,
int
tableSize
,
...
@@ -31,8 +32,8 @@ __global__ void KeMatrixAddRows(real* output, int ldo,
...
@@ -31,8 +32,8 @@ __global__ void KeMatrixAddRows(real* output, int ldo,
while
(
idy
<
numSamples
)
{
while
(
idy
<
numSamples
)
{
int
tableId
=
ids
[
idy
];
int
tableId
=
ids
[
idy
];
if
((
0
<=
tableId
)
&&
(
tableId
<
tableSize
))
{
if
((
0
<=
tableId
)
&&
(
tableId
<
tableSize
))
{
real
*
out
=
output
+
idy
*
ldo
;
real
*
out
=
output
+
idy
*
ldo
;
real
*
tab
=
table
+
tableId
*
ldt
;
real
*
tab
=
table
+
tableId
*
ldt
;
for
(
int
i
=
idx
;
i
<
dim
;
i
+=
blockDimX
)
{
for
(
int
i
=
idx
;
i
<
dim
;
i
+=
blockDimX
)
{
if
(
AddRow
)
{
if
(
AddRow
)
{
paddle
::
paddleAtomicAdd
(
&
tab
[
i
],
out
[
i
]);
paddle
::
paddleAtomicAdd
(
&
tab
[
i
],
out
[
i
]);
...
@@ -45,8 +46,10 @@ __global__ void KeMatrixAddRows(real* output, int ldo,
...
@@ -45,8 +46,10 @@ __global__ void KeMatrixAddRows(real* output, int ldo,
}
}
}
}
void
hl_matrix_select_rows
(
real
*
output
,
int
ldo
,
void
hl_matrix_select_rows
(
real
*
output
,
real
*
table
,
int
ldt
,
int
ldo
,
real
*
table
,
int
ldt
,
int
*
ids
,
int
*
ids
,
int
numSamples
,
int
numSamples
,
int
tableSize
,
int
tableSize
,
...
@@ -57,14 +60,16 @@ void hl_matrix_select_rows(real* output, int ldo,
...
@@ -57,14 +60,16 @@ void hl_matrix_select_rows(real* output, int ldo,
dim3
threads
(
128
,
8
);
dim3
threads
(
128
,
8
);
dim3
grid
(
8
,
1
);
dim3
grid
(
8
,
1
);
KeMatrixAddRows
<
128
,
8
,
8
,
0
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
KeMatrixAddRows
<
128
,
8
,
8
,
0
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
(
output
,
ldo
,
table
,
ldt
,
ids
,
numSamples
,
tableSize
,
dim
);
output
,
ldo
,
table
,
ldt
,
ids
,
numSamples
,
tableSize
,
dim
);
CHECK_SYNC
(
"hl_matrix_select_rows failed"
);
CHECK_SYNC
(
"hl_matrix_select_rows failed"
);
}
}
void
hl_matrix_add_to_rows
(
real
*
table
,
int
ldt
,
void
hl_matrix_add_to_rows
(
real
*
table
,
real
*
input
,
int
ldi
,
int
ldt
,
real
*
input
,
int
ldi
,
int
*
ids
,
int
*
ids
,
int
numSamples
,
int
numSamples
,
int
tableSize
,
int
tableSize
,
...
@@ -75,16 +80,15 @@ void hl_matrix_add_to_rows(real* table, int ldt,
...
@@ -75,16 +80,15 @@ void hl_matrix_add_to_rows(real* table, int ldt,
dim3
threads
(
128
,
8
);
dim3
threads
(
128
,
8
);
dim3
grid
(
8
,
1
);
dim3
grid
(
8
,
1
);
KeMatrixAddRows
<
128
,
8
,
8
,
1
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
KeMatrixAddRows
<
128
,
8
,
8
,
1
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
(
input
,
ldi
,
table
,
ldt
,
ids
,
numSamples
,
tableSize
,
dim
);
input
,
ldi
,
table
,
ldt
,
ids
,
numSamples
,
tableSize
,
dim
);
CHECK_SYNC
(
"hl_matrix_add_to_rows failed"
);
CHECK_SYNC
(
"hl_matrix_add_to_rows failed"
);
}
}
template
<
class
T
,
int
blockDimX
,
int
gridDimX
>
template
<
class
T
,
int
blockDimX
,
int
gridDimX
>
__global__
void
KeVectorSelect
(
T
*
dst
,
int
sized
,
__global__
void
KeVectorSelect
(
const
T
*
src
,
int
sizes
,
T
*
dst
,
int
sized
,
const
T
*
src
,
int
sizes
,
const
int
*
ids
,
int
sizei
)
{
const
int
*
ids
,
int
sizei
)
{
int
idx
=
threadIdx
.
x
+
blockDimX
*
blockIdx
.
x
;
int
idx
=
threadIdx
.
x
+
blockDimX
*
blockIdx
.
x
;
while
(
idx
<
sizei
)
{
while
(
idx
<
sizei
)
{
int
index
=
ids
[
idx
];
int
index
=
ids
[
idx
];
...
@@ -95,9 +99,8 @@ __global__ void KeVectorSelect(T* dst, int sized,
...
@@ -95,9 +99,8 @@ __global__ void KeVectorSelect(T* dst, int sized,
}
}
template
<
class
T
>
template
<
class
T
>
void
hl_vector_select_from
(
T
*
dst
,
int
sized
,
void
hl_vector_select_from
(
const
T
*
src
,
int
sizes
,
T
*
dst
,
int
sized
,
const
T
*
src
,
int
sizes
,
const
int
*
ids
,
int
sizei
)
{
const
int
*
ids
,
int
sizei
)
{
CHECK_NOTNULL
(
dst
);
CHECK_NOTNULL
(
dst
);
CHECK_NOTNULL
(
src
);
CHECK_NOTNULL
(
src
);
CHECK_NOTNULL
(
ids
);
CHECK_NOTNULL
(
ids
);
...
@@ -105,18 +108,17 @@ void hl_vector_select_from(T* dst, int sized,
...
@@ -105,18 +108,17 @@ void hl_vector_select_from(T* dst, int sized,
dim3
threads
(
512
,
1
);
dim3
threads
(
512
,
1
);
dim3
grid
(
8
,
1
);
dim3
grid
(
8
,
1
);
KeVectorSelect
<
T
,
512
,
8
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
KeVectorSelect
<
T
,
512
,
8
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
(
dst
,
sized
,
src
,
sizes
,
ids
,
sizei
);
dst
,
sized
,
src
,
sizes
,
ids
,
sizei
);
CHECK_SYNC
(
"hl_vector_select_from failed"
);
CHECK_SYNC
(
"hl_vector_select_from failed"
);
}
}
template
template
void
hl_vector_select_from
(
real
*
dst
,
void
hl_vector_select_from
(
real
*
dst
,
int
sized
,
int
sized
,
const
real
*
src
,
int
sizes
,
const
real
*
src
,
const
int
*
ids
,
int
sizei
);
int
sizes
,
template
const
int
*
ids
,
void
hl_vector_select_from
(
int
*
dst
,
int
sized
,
int
sizei
);
const
int
*
src
,
int
sizes
,
template
void
hl_vector_select_from
(
const
int
*
ids
,
int
sizei
);
int
*
dst
,
int
sized
,
const
int
*
src
,
int
sizes
,
const
int
*
ids
,
int
sizei
);
paddle/cuda/src/hl_top_k.cu
浏览文件 @
1d4fa243
...
@@ -12,45 +12,37 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,45 +12,37 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "hl_base.h"
#include "hl_base.h"
#include "hl_top_k.h"
#include "hl_sparse.ph"
#include "hl_sparse.ph"
#include "hl_top_k.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Logging.h"
// using namespace hppl;
// using namespace hppl;
struct
Pair
{
struct
Pair
{
__device__
__forceinline__
__device__
__forceinline__
Pair
()
{}
Pair
()
{}
__device__
__forceinline__
__device__
__forceinline__
Pair
(
real
value
,
int
id
)
:
v_
(
value
),
id_
(
id
)
{}
Pair
(
real
value
,
int
id
)
:
v_
(
value
),
id_
(
id
)
{}
__device__
__forceinline__
__device__
__forceinline__
void
set
(
real
value
,
int
id
)
{
void
set
(
real
value
,
int
id
)
{
v_
=
value
;
v_
=
value
;
id_
=
id
;
id_
=
id
;
}
}
__device__
__forceinline__
__device__
__forceinline__
void
operator
=
(
const
Pair
&
in
)
{
void
operator
=
(
const
Pair
&
in
)
{
v_
=
in
.
v_
;
v_
=
in
.
v_
;
id_
=
in
.
id_
;
id_
=
in
.
id_
;
}
}
__device__
__forceinline__
__device__
__forceinline__
bool
operator
<
(
const
real
value
)
const
{
bool
operator
<
(
const
real
value
)
const
{
return
(
v_
<
value
);
return
(
v_
<
value
);
}
}
__device__
__forceinline__
__device__
__forceinline__
bool
operator
<
(
const
Pair
&
in
)
const
{
bool
operator
<
(
const
Pair
&
in
)
const
{
return
(
v_
<
in
.
v_
)
||
((
v_
==
in
.
v_
)
&&
(
id_
>
in
.
id_
));
return
(
v_
<
in
.
v_
)
||
((
v_
==
in
.
v_
)
&&
(
id_
>
in
.
id_
));
}
}
__device__
__forceinline__
__device__
__forceinline__
bool
operator
>
(
const
Pair
&
in
)
const
{
bool
operator
>
(
const
Pair
&
in
)
const
{
return
(
v_
>
in
.
v_
)
||
((
v_
==
in
.
v_
)
&&
(
id_
<
in
.
id_
));
return
(
v_
>
in
.
v_
)
||
((
v_
==
in
.
v_
)
&&
(
id_
<
in
.
id_
));
}
}
...
@@ -58,8 +50,9 @@ struct Pair {
...
@@ -58,8 +50,9 @@ struct Pair {
int
id_
;
int
id_
;
};
};
__device__
__forceinline__
__device__
__forceinline__
void
addTo
(
Pair
topK
[],
void
addTo
(
Pair
topK
[],
const
Pair
&
p
,
int
beamSize
)
{
const
Pair
&
p
,
int
beamSize
)
{
for
(
int
k
=
beamSize
-
2
;
k
>=
0
;
k
--
)
{
for
(
int
k
=
beamSize
-
2
;
k
>=
0
;
k
--
)
{
if
(
topK
[
k
]
<
p
)
{
if
(
topK
[
k
]
<
p
)
{
topK
[
k
+
1
]
=
topK
[
k
];
topK
[
k
+
1
]
=
topK
[
k
];
...
@@ -71,9 +64,8 @@ void addTo(Pair topK[], const Pair &p, int beamSize) {
...
@@ -71,9 +64,8 @@ void addTo(Pair topK[], const Pair &p, int beamSize) {
topK
[
0
]
=
p
;
topK
[
0
]
=
p
;
}
}
template
<
int
beamSize
>
template
<
int
beamSize
>
__device__
__forceinline__
__device__
__forceinline__
void
addTo
(
Pair
topK
[],
const
Pair
&
p
)
{
void
addTo
(
Pair
topK
[],
const
Pair
&
p
)
{
for
(
int
k
=
beamSize
-
2
;
k
>=
0
;
k
--
)
{
for
(
int
k
=
beamSize
-
2
;
k
>=
0
;
k
--
)
{
if
(
topK
[
k
]
<
p
)
{
if
(
topK
[
k
]
<
p
)
{
topK
[
k
+
1
]
=
topK
[
k
];
topK
[
k
+
1
]
=
topK
[
k
];
...
@@ -85,9 +77,9 @@ void addTo(Pair topK[], const Pair &p) {
...
@@ -85,9 +77,9 @@ void addTo(Pair topK[], const Pair &p) {
topK
[
0
]
=
p
;
topK
[
0
]
=
p
;
}
}
template
<
int
blockSize
>
template
<
int
blockSize
>
__device__
__forceinline__
__device__
__forceinline__
void
getTopK
(
void
getTopK
(
Pair
topK
[],
real
*
src
,
int
idx
,
int
dim
,
int
beamSize
)
{
Pair
topK
[],
real
*
src
,
int
idx
,
int
dim
,
int
beamSize
)
{
while
(
idx
<
dim
)
{
while
(
idx
<
dim
)
{
if
(
topK
[
beamSize
-
1
]
<
src
[
idx
])
{
if
(
topK
[
beamSize
-
1
]
<
src
[
idx
])
{
Pair
tmp
(
src
[
idx
],
idx
);
Pair
tmp
(
src
[
idx
],
idx
);
...
@@ -97,10 +89,9 @@ void getTopK(Pair topK[], real *src, int idx, int dim, int beamSize) {
...
@@ -97,10 +89,9 @@ void getTopK(Pair topK[], real *src, int idx, int dim, int beamSize) {
}
}
}
}
template
<
int
blockSize
>
template
<
int
blockSize
>
__device__
__forceinline__
__device__
__forceinline__
void
getTopK
(
void
getTopK
(
Pair
topK
[],
real
*
src
,
int
idx
,
int
dim
,
Pair
topK
[],
real
*
src
,
int
idx
,
int
dim
,
const
Pair
&
max
,
int
beamSize
)
{
const
Pair
&
max
,
int
beamSize
)
{
while
(
idx
<
dim
)
{
while
(
idx
<
dim
)
{
if
(
topK
[
beamSize
-
1
]
<
src
[
idx
])
{
if
(
topK
[
beamSize
-
1
]
<
src
[
idx
])
{
Pair
tmp
(
src
[
idx
],
idx
);
Pair
tmp
(
src
[
idx
],
idx
);
...
@@ -112,10 +103,9 @@ void getTopK(Pair topK[], real *src, int idx, int dim,
...
@@ -112,10 +103,9 @@ void getTopK(Pair topK[], real *src, int idx, int dim,
}
}
}
}
template
<
int
blockSize
>
template
<
int
blockSize
>
__device__
__forceinline__
__device__
__forceinline__
void
getTopK
(
void
getTopK
(
Pair
topK
[],
real
*
val
,
int
*
col
,
Pair
topK
[],
real
*
val
,
int
*
col
,
int
idx
,
int
dim
,
int
beamSize
)
{
int
idx
,
int
dim
,
int
beamSize
)
{
while
(
idx
<
dim
)
{
while
(
idx
<
dim
)
{
if
(
topK
[
beamSize
-
1
]
<
val
[
idx
])
{
if
(
topK
[
beamSize
-
1
]
<
val
[
idx
])
{
Pair
tmp
(
val
[
idx
],
col
[
idx
]);
Pair
tmp
(
val
[
idx
],
col
[
idx
]);
...
@@ -125,10 +115,14 @@ void getTopK(Pair topK[], real *val, int *col,
...
@@ -125,10 +115,14 @@ void getTopK(Pair topK[], real *val, int *col,
}
}
}
}
template
<
int
blockSize
>
template
<
int
blockSize
>
__device__
__forceinline__
__device__
__forceinline__
void
getTopK
(
Pair
topK
[],
void
getTopK
(
Pair
topK
[],
real
*
val
,
int
*
col
,
int
idx
,
int
dim
,
real
*
val
,
const
Pair
&
max
,
int
beamSize
)
{
int
*
col
,
int
idx
,
int
dim
,
const
Pair
&
max
,
int
beamSize
)
{
while
(
idx
<
dim
)
{
while
(
idx
<
dim
)
{
if
(
topK
[
beamSize
-
1
]
<
val
[
idx
])
{
if
(
topK
[
beamSize
-
1
]
<
val
[
idx
])
{
Pair
tmp
(
val
[
idx
],
col
[
idx
]);
Pair
tmp
(
val
[
idx
],
col
[
idx
]);
...
@@ -140,12 +134,16 @@ void getTopK(Pair topK[], real *val, int *col, int idx, int dim,
...
@@ -140,12 +134,16 @@ void getTopK(Pair topK[], real *val, int *col, int idx, int dim,
}
}
}
}
template
<
int
maxLength
,
int
blockSize
>
template
<
int
maxLength
,
int
blockSize
>
__device__
__forceinline__
__device__
__forceinline__
void
threadGetTopK
(
Pair
topK
[],
void
threadGetTopK
(
Pair
topK
[],
int
&
beam
,
int
beamSize
,
int
&
beam
,
real
*
src
,
int
beamSize
,
bool
&
firstStep
,
bool
&
isEmpty
,
Pair
&
max
,
real
*
src
,
int
dim
,
const
int
tid
)
{
bool
&
firstStep
,
bool
&
isEmpty
,
Pair
&
max
,
int
dim
,
const
int
tid
)
{
if
(
beam
>
0
)
{
if
(
beam
>
0
)
{
int
length
=
beam
<
beamSize
?
beam
:
beamSize
;
int
length
=
beam
<
beamSize
?
beam
:
beamSize
;
if
(
firstStep
)
{
if
(
firstStep
)
{
...
@@ -160,8 +158,7 @@ void threadGetTopK(Pair topK[], int& beam, int beamSize,
...
@@ -160,8 +158,7 @@ void threadGetTopK(Pair topK[], int& beam, int beamSize,
}
}
}
}
if
(
!
isEmpty
)
{
if
(
!
isEmpty
)
{
getTopK
<
blockSize
>
(
topK
+
maxLength
-
beam
,
src
,
tid
,
dim
,
getTopK
<
blockSize
>
(
topK
+
maxLength
-
beam
,
src
,
tid
,
dim
,
max
,
length
);
max
,
length
);
}
}
}
}
...
@@ -171,12 +168,17 @@ void threadGetTopK(Pair topK[], int& beam, int beamSize,
...
@@ -171,12 +168,17 @@ void threadGetTopK(Pair topK[], int& beam, int beamSize,
}
}
}
}
template
<
int
maxLength
,
int
blockSize
>
template
<
int
maxLength
,
int
blockSize
>
__device__
__forceinline__
__device__
__forceinline__
void
threadGetTopK
(
Pair
topK
[],
void
threadGetTopK
(
Pair
topK
[],
int
&
beam
,
int
beamSize
,
int
&
beam
,
real
*
val
,
int
*
col
,
int
beamSize
,
bool
&
firstStep
,
bool
&
isEmpty
,
Pair
&
max
,
real
*
val
,
int
dim
,
const
int
tid
)
{
int
*
col
,
bool
&
firstStep
,
bool
&
isEmpty
,
Pair
&
max
,
int
dim
,
const
int
tid
)
{
if
(
beam
>
0
)
{
if
(
beam
>
0
)
{
int
length
=
beam
<
beamSize
?
beam
:
beamSize
;
int
length
=
beam
<
beamSize
?
beam
:
beamSize
;
if
(
firstStep
)
{
if
(
firstStep
)
{
...
@@ -191,8 +193,8 @@ void threadGetTopK(Pair topK[], int& beam, int beamSize,
...
@@ -191,8 +193,8 @@ void threadGetTopK(Pair topK[], int& beam, int beamSize,
}
}
}
}
if
(
!
isEmpty
)
{
if
(
!
isEmpty
)
{
getTopK
<
blockSize
>
(
topK
+
maxLength
-
beam
,
val
,
col
,
tid
,
dim
,
getTopK
<
blockSize
>
(
max
,
length
);
topK
+
maxLength
-
beam
,
val
,
col
,
tid
,
dim
,
max
,
length
);
}
}
}
}
...
@@ -202,12 +204,16 @@ void threadGetTopK(Pair topK[], int& beam, int beamSize,
...
@@ -202,12 +204,16 @@ void threadGetTopK(Pair topK[], int& beam, int beamSize,
}
}
}
}
template
<
int
maxLength
,
int
blockSize
>
template
<
int
maxLength
,
int
blockSize
>
__device__
__forceinline__
__device__
__forceinline__
void
blockReduce
(
Pair
*
shTopK
,
void
blockReduce
(
Pair
*
shTopK
,
int
*
maxId
,
Pair
topK
[],
int
*
maxId
,
real
**
topVal
,
int
**
topIds
,
Pair
topK
[],
int
&
beam
,
int
&
beamSize
,
real
**
topVal
,
const
int
tid
,
const
int
warp
)
{
int
**
topIds
,
int
&
beam
,
int
&
beamSize
,
const
int
tid
,
const
int
warp
)
{
while
(
true
)
{
while
(
true
)
{
__syncthreads
();
__syncthreads
();
if
(
tid
<
blockSize
/
2
)
{
if
(
tid
<
blockSize
/
2
)
{
...
@@ -218,7 +224,7 @@ void blockReduce(Pair* shTopK, int* maxId, Pair topK[],
...
@@ -218,7 +224,7 @@ void blockReduce(Pair* shTopK, int* maxId, Pair topK[],
}
}
}
}
__syncthreads
();
__syncthreads
();
for
(
int
stride
=
blockSize
/
4
;
stride
>
0
;
stride
=
stride
/
2
)
{
for
(
int
stride
=
blockSize
/
4
;
stride
>
0
;
stride
=
stride
/
2
)
{
if
(
tid
<
stride
)
{
if
(
tid
<
stride
)
{
if
(
shTopK
[
maxId
[
tid
]]
<
shTopK
[
maxId
[
tid
+
stride
]])
{
if
(
shTopK
[
maxId
[
tid
]]
<
shTopK
[
maxId
[
tid
+
stride
]])
{
maxId
[
tid
]
=
maxId
[
tid
+
stride
];
maxId
[
tid
]
=
maxId
[
tid
+
stride
];
...
@@ -257,10 +263,12 @@ void blockReduce(Pair* shTopK, int* maxId, Pair topK[],
...
@@ -257,10 +263,12 @@ void blockReduce(Pair* shTopK, int* maxId, Pair topK[],
* 3. go to the second setp, until one thread's topK value is null;
* 3. go to the second setp, until one thread's topK value is null;
* 4. go to the first setp, until get the topK value.
* 4. go to the first setp, until get the topK value.
*/
*/
template
<
int
maxLength
,
int
blockSize
>
template
<
int
maxLength
,
int
blockSize
>
__global__
void
KeMatrixTopK
(
real
*
topVal
,
int
ldv
,
__global__
void
KeMatrixTopK
(
real
*
topVal
,
int
*
topIds
,
int
ldv
,
real
*
src
,
int
lds
,
int
*
topIds
,
real
*
src
,
int
lds
,
int
dim
,
int
dim
,
int
beamSize
)
{
int
beamSize
)
{
__shared__
Pair
shTopK
[
blockSize
];
__shared__
Pair
shTopK
[
blockSize
];
...
@@ -271,7 +279,7 @@ __global__ void KeMatrixTopK(real* topVal, int ldv,
...
@@ -271,7 +279,7 @@ __global__ void KeMatrixTopK(real* topVal, int ldv,
topVal
+=
blockIdx
.
x
*
ldv
;
topVal
+=
blockIdx
.
x
*
ldv
;
topIds
+=
blockIdx
.
x
*
beamSize
;
topIds
+=
blockIdx
.
x
*
beamSize
;
Pair
topK
[
maxLength
];
// NOLINT
Pair
topK
[
maxLength
];
// NOLINT
int
beam
=
maxLength
;
int
beam
=
maxLength
;
Pair
max
;
Pair
max
;
bool
isEmpty
=
false
;
bool
isEmpty
=
false
;
...
@@ -281,18 +289,19 @@ __global__ void KeMatrixTopK(real* topVal, int ldv,
...
@@ -281,18 +289,19 @@ __global__ void KeMatrixTopK(real* topVal, int ldv,
topK
[
k
].
set
(
-
HL_FLOAT_MAX
,
-
1
);
topK
[
k
].
set
(
-
HL_FLOAT_MAX
,
-
1
);
}
}
while
(
beamSize
)
{
while
(
beamSize
)
{
threadGetTopK
<
maxLength
,
blockSize
>
threadGetTopK
<
maxLength
,
blockSize
>
(
(
topK
,
beam
,
beamSize
,
src
,
firstStep
,
isEmpty
,
max
,
dim
,
tid
);
topK
,
beam
,
beamSize
,
src
,
firstStep
,
isEmpty
,
max
,
dim
,
tid
);
shTopK
[
tid
]
=
topK
[
0
];
shTopK
[
tid
]
=
topK
[
0
];
blockReduce
<
maxLength
,
blockSize
>
blockReduce
<
maxLength
,
blockSize
>
(
(
shTopK
,
maxId
,
topK
,
&
topVal
,
&
topIds
,
beam
,
beamSize
,
tid
,
warp
);
shTopK
,
maxId
,
topK
,
&
topVal
,
&
topIds
,
beam
,
beamSize
,
tid
,
warp
);
}
}
}
}
template
<
int
maxLength
,
int
blockSize
>
template
<
int
maxLength
,
int
blockSize
>
__global__
void
KeSMatrixTopK
(
real
*
topVal
,
int
ldv
,
__global__
void
KeSMatrixTopK
(
real
*
topVal
,
int
*
topIds
,
int
ldv
,
int
*
topIds
,
real
*
val
,
real
*
val
,
int
*
row
,
int
*
row
,
int
*
col
,
int
*
col
,
...
@@ -304,7 +313,7 @@ __global__ void KeSMatrixTopK(real* topVal, int ldv,
...
@@ -304,7 +313,7 @@ __global__ void KeSMatrixTopK(real* topVal, int ldv,
topVal
+=
blockIdx
.
x
*
ldv
;
topVal
+=
blockIdx
.
x
*
ldv
;
topIds
+=
blockIdx
.
x
*
beamSize
;
topIds
+=
blockIdx
.
x
*
beamSize
;
Pair
topK
[
maxLength
];
// NOLINT
Pair
topK
[
maxLength
];
// NOLINT
int
beam
=
maxLength
;
int
beam
=
maxLength
;
Pair
max
;
Pair
max
;
bool
isEmpty
=
false
;
bool
isEmpty
=
false
;
...
@@ -330,18 +339,20 @@ __global__ void KeSMatrixTopK(real* topVal, int ldv,
...
@@ -330,18 +339,20 @@ __global__ void KeSMatrixTopK(real* topVal, int ldv,
topK
[
k
].
set
(
-
HL_FLOAT_MAX
,
-
1
);
topK
[
k
].
set
(
-
HL_FLOAT_MAX
,
-
1
);
}
}
while
(
beamSize
)
{
while
(
beamSize
)
{
threadGetTopK
<
maxLength
,
blockSize
>
threadGetTopK
<
maxLength
,
blockSize
>
(
(
topK
,
beam
,
beamSize
,
val
,
col
,
firstStep
,
isEmpty
,
max
,
dim
,
tid
);
topK
,
beam
,
beamSize
,
val
,
col
,
firstStep
,
isEmpty
,
max
,
dim
,
tid
);
shTopK
[
tid
]
=
topK
[
0
];
shTopK
[
tid
]
=
topK
[
0
];
blockReduce
<
maxLength
,
blockSize
>
blockReduce
<
maxLength
,
blockSize
>
(
(
shTopK
,
maxId
,
topK
,
&
topVal
,
&
topIds
,
beam
,
beamSize
,
tid
,
warp
);
shTopK
,
maxId
,
topK
,
&
topVal
,
&
topIds
,
beam
,
beamSize
,
tid
,
warp
);
}
}
}
}
void
hl_matrix_top_k
(
real
*
topVal
,
int
ldv
,
void
hl_matrix_top_k
(
real
*
topVal
,
int
*
topIds
,
int
ldv
,
real
*
src
,
int
lds
,
int
*
topIds
,
real
*
src
,
int
lds
,
int
dim
,
int
dim
,
int
beamSize
,
int
beamSize
,
int
numSamples
)
{
int
numSamples
)
{
...
@@ -353,33 +364,32 @@ void hl_matrix_top_k(real* topVal, int ldv,
...
@@ -353,33 +364,32 @@ void hl_matrix_top_k(real* topVal, int ldv,
dim3
threads
(
256
,
1
);
dim3
threads
(
256
,
1
);
dim3
grid
(
numSamples
,
1
);
dim3
grid
(
numSamples
,
1
);
KeMatrixTopK
<
5
,
256
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
KeMatrixTopK
<
5
,
256
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
(
topVal
,
ldv
,
topIds
,
src
,
lds
,
dim
,
beamSize
);
topVal
,
ldv
,
topIds
,
src
,
lds
,
dim
,
beamSize
);
CHECK_SYNC
(
"hl_matrix_top_k failed"
);
CHECK_SYNC
(
"hl_matrix_top_k failed"
);
}
}
void
hl_sparse_matrix_top_k
(
real
*
topVal
,
int
ldv
,
void
hl_sparse_matrix_top_k
(
real
*
topVal
,
int
*
topIds
,
int
ldv
,
int
*
topIds
,
hl_sparse_matrix_s
src
,
hl_sparse_matrix_s
src
,
int
beamSize
,
int
beamSize
,
int
numSamples
)
{
int
numSamples
)
{
CHECK_NOTNULL
(
topVal
);
CHECK_NOTNULL
(
topVal
);
CHECK_NOTNULL
(
topIds
);
CHECK_NOTNULL
(
topIds
);
CHECK_NOTNULL
(
src
);
CHECK_NOTNULL
(
src
);
CHECK_EQ
(
src
->
format
,
HL_SPARSE_CSR
)
CHECK_EQ
(
src
->
format
,
HL_SPARSE_CSR
)
<<
"sparse matrix format error!"
;
<<
"sparse matrix format error!"
;
hl_csr_matrix
csr
=
(
hl_csr_matrix
)
src
->
matrix
;
hl_csr_matrix
csr
=
(
hl_csr_matrix
)
src
->
matrix
;
if
(
csr
->
csr_val
==
NULL
||
csr
->
csr_row
==
NULL
||
if
(
csr
->
csr_val
==
NULL
||
csr
->
csr_row
==
NULL
||
csr
->
csr_col
==
NULL
)
{
csr
->
csr_col
==
NULL
)
{
LOG
(
FATAL
)
<<
"parameter src is null!"
;
LOG
(
FATAL
)
<<
"parameter src is null!"
;
}
}
dim3
threads
(
256
,
1
);
dim3
threads
(
256
,
1
);
dim3
grid
(
numSamples
,
1
);
dim3
grid
(
numSamples
,
1
);
KeSMatrixTopK
<
5
,
256
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
KeSMatrixTopK
<
5
,
256
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
(
topVal
,
ldv
,
topIds
,
csr
->
csr_val
,
csr
->
csr_row
,
csr
->
csr_col
,
beamSize
);
topVal
,
ldv
,
topIds
,
csr
->
csr_val
,
csr
->
csr_row
,
csr
->
csr_col
,
beamSize
);
CHECK_SYNC
(
"hl_sparse_matrix_top_k failed"
);
CHECK_SYNC
(
"hl_sparse_matrix_top_k failed"
);
}
}
...
@@ -392,10 +402,12 @@ void hl_sparse_matrix_top_k(real* topVal, int ldv,
...
@@ -392,10 +402,12 @@ void hl_sparse_matrix_top_k(real* topVal, int ldv,
* 3. go to the second setp, until one thread's topK value is null;
* 3. go to the second setp, until one thread's topK value is null;
* 4. go to the first setp, until get the topK value.
* 4. go to the first setp, until get the topK value.
*/
*/
template
<
int
maxLength
,
int
blockSize
>
template
<
int
maxLength
,
int
blockSize
>
__global__
void
KeMatrixTopKClassificationError
(
real
*
topVal
,
int
ldv
,
__global__
void
KeMatrixTopKClassificationError
(
real
*
topVal
,
int
*
topIds
,
int
ldv
,
real
*
src
,
int
lds
,
int
*
topIds
,
real
*
src
,
int
lds
,
int
dim
,
int
dim
,
int
beamSize
,
int
beamSize
,
int
*
label
,
int
*
label
,
...
@@ -408,7 +420,7 @@ __global__ void KeMatrixTopKClassificationError(real* topVal, int ldv,
...
@@ -408,7 +420,7 @@ __global__ void KeMatrixTopKClassificationError(real* topVal, int ldv,
topVal
+=
blockIdx
.
x
*
ldv
;
topVal
+=
blockIdx
.
x
*
ldv
;
topIds
+=
blockIdx
.
x
*
beamSize
;
topIds
+=
blockIdx
.
x
*
beamSize
;
Pair
topK
[
maxLength
];
// NOLINT
Pair
topK
[
maxLength
];
// NOLINT
int
beam
=
maxLength
;
int
beam
=
maxLength
;
Pair
max
;
Pair
max
;
bool
isEmpty
=
false
;
bool
isEmpty
=
false
;
...
@@ -420,34 +432,36 @@ __global__ void KeMatrixTopKClassificationError(real* topVal, int ldv,
...
@@ -420,34 +432,36 @@ __global__ void KeMatrixTopKClassificationError(real* topVal, int ldv,
}
}
while
(
beamSize
)
{
while
(
beamSize
)
{
threadGetTopK
<
maxLength
,
blockSize
>
threadGetTopK
<
maxLength
,
blockSize
>
(
(
topK
,
beam
,
beamSize
,
src
,
firstStep
,
isEmpty
,
max
,
dim
,
tid
);
topK
,
beam
,
beamSize
,
src
,
firstStep
,
isEmpty
,
max
,
dim
,
tid
);
shTopK
[
tid
]
=
topK
[
0
];
shTopK
[
tid
]
=
topK
[
0
];
blockReduce
<
maxLength
,
blockSize
>
blockReduce
<
maxLength
,
blockSize
>
(
(
shTopK
,
maxId
,
topK
,
&
topVal
,
&
topIds
,
beam
,
beamSize
,
tid
,
warp
);
shTopK
,
maxId
,
topK
,
&
topVal
,
&
topIds
,
beam
,
beamSize
,
tid
,
warp
);
}
}
__syncthreads
();
__syncthreads
();
if
(
tid
==
0
)
{
if
(
tid
==
0
)
{
for
(
int
i
=
0
;
i
<
topkSize
;
i
++
)
{
for
(
int
i
=
0
;
i
<
topkSize
;
i
++
)
{
if
(
*--
topIds
==
label
[
blockIdx
.
x
])
{
if
(
*--
topIds
==
label
[
blockIdx
.
x
])
{
recResult
[
blockIdx
.
x
]
=
0
;
recResult
[
blockIdx
.
x
]
=
0
;
break
;
break
;
}
}
recResult
[
blockIdx
.
x
]
=
1.0
f
;
recResult
[
blockIdx
.
x
]
=
1.0
f
;
}
}
}
}
}
}
void
hl_matrix_classification_error
(
real
*
topVal
,
int
ldv
,
void
hl_matrix_classification_error
(
real
*
topVal
,
int
*
topIds
,
int
ldv
,
real
*
src
,
int
lds
,
int
*
topIds
,
int
dim
,
real
*
src
,
int
topkSize
,
int
lds
,
int
numSamples
,
int
dim
,
int
*
label
,
int
topkSize
,
real
*
recResult
)
{
int
numSamples
,
int
*
label
,
real
*
recResult
)
{
CHECK_NOTNULL
(
topVal
);
CHECK_NOTNULL
(
topVal
);
CHECK_NOTNULL
(
topIds
);
CHECK_NOTNULL
(
topIds
);
CHECK_NOTNULL
(
src
);
CHECK_NOTNULL
(
src
);
...
@@ -456,9 +470,8 @@ void hl_matrix_classification_error(real* topVal, int ldv,
...
@@ -456,9 +470,8 @@ void hl_matrix_classification_error(real* topVal, int ldv,
dim3
threads
(
256
,
1
);
dim3
threads
(
256
,
1
);
dim3
grid
(
numSamples
,
1
);
dim3
grid
(
numSamples
,
1
);
KeMatrixTopKClassificationError
<
5
,
256
>
KeMatrixTopKClassificationError
<
5
,
256
><<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
<<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
topVal
,
ldv
,
topIds
,
src
,
lds
,
dim
,
topkSize
,
label
,
recResult
);
(
topVal
,
ldv
,
topIds
,
src
,
lds
,
dim
,
topkSize
,
label
,
recResult
);
CHECK_SYNC
(
"hl_matrix_top_k classification error failed"
);
CHECK_SYNC
(
"hl_matrix_top_k classification error failed"
);
}
}
paddle/framework/attr_type.proto
浏览文件 @
1d4fa243
...
@@ -12,17 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,17 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
syntax
=
"proto2"
;
syntax
=
"proto2"
;
package
paddle
.
framework
;
package
paddle
.
framework
;
// Attribute Type for paddle's Op.
// Attribute Type for paddle's Op.
// Op contains many attributes. Each type of attributes could be different.
// Op contains many attributes. Each type of attributes could be different.
// The AttrType will be shared between AttrDesc and AttrProto.
// The AttrType will be shared between AttrDesc and AttrProto.
enum
AttrType
{
enum
AttrType
{
INT
=
0
;
INT
=
0
;
FLOAT
=
1
;
FLOAT
=
1
;
STRING
=
2
;
STRING
=
2
;
INTS
=
3
;
INTS
=
3
;
FLOATS
=
4
;
FLOATS
=
4
;
STRINGS
=
5
;
STRINGS
=
5
;
}
}
\ No newline at end of file
paddle/framework/op_desc.proto
浏览文件 @
1d4fa243
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
syntax
=
"proto2"
;
syntax
=
"proto2"
;
package
paddle
.
framework
;
package
paddle
.
framework
;
import
"attr_type.proto"
;
import
"attr_type.proto"
;
...
@@ -22,14 +22,14 @@ import "attr_type.proto";
...
@@ -22,14 +22,14 @@ import "attr_type.proto";
//
//
// e.g, for scale=3.0: name=scala, type=AttrType.FLOAT, value=3.0
// e.g, for scale=3.0: name=scala, type=AttrType.FLOAT, value=3.0
message
AttrDesc
{
message
AttrDesc
{
required
string
name
=
1
;
required
string
name
=
1
;
required
AttrType
type
=
2
;
required
AttrType
type
=
2
;
optional
int32
i
=
3
;
optional
int32
i
=
3
;
optional
float
f
=
4
;
optional
float
f
=
4
;
optional
string
s
=
5
;
optional
string
s
=
5
;
repeated
int32
ints
=
6
;
repeated
int32
ints
=
6
;
repeated
float
floats
=
7
;
repeated
float
floats
=
7
;
repeated
string
strings
=
8
;
repeated
string
strings
=
8
;
};
};
// Protocol Message to describe an Operator.
// Protocol Message to describe an Operator.
...
@@ -42,15 +42,15 @@ message AttrDesc {
...
@@ -42,15 +42,15 @@ message AttrDesc {
// 3rd-party language can build this proto message and call
// 3rd-party language can build this proto message and call
// AddOp(const OpDesc& op_desc) of Paddle core to create an Operator.
// AddOp(const OpDesc& op_desc) of Paddle core to create an Operator.
message
OpDesc
{
message
OpDesc
{
// input names of this Operator.
// input names of this Operator.
repeated
string
inputs
=
1
;
repeated
string
inputs
=
1
;
// output names of this Operator.
// output names of this Operator.
repeated
string
outputs
=
2
;
repeated
string
outputs
=
2
;
// type of this Operator, such as "add", "sub", "fc".
// type of this Operator, such as "add", "sub", "fc".
required
string
type
=
3
;
required
string
type
=
3
;
// Attributes of this Operator. e.g., scale=3.0 in cosine op.
// Attributes of this Operator. e.g., scale=3.0 in cosine op.
repeated
AttrDesc
attrs
=
4
;
repeated
AttrDesc
attrs
=
4
;
};
};
\ No newline at end of file
paddle/framework/op_proto.proto
浏览文件 @
1d4fa243
此差异已折叠。
点击以展开。
paddle/function/ContextProjectionOpGpu.cu
浏览文件 @
1d4fa243
此差异已折叠。
点击以展开。
paddle/function/CosSimOpGpu.cu
浏览文件 @
1d4fa243
此差异已折叠。
点击以展开。
paddle/function/CropOpGpu.cu
浏览文件 @
1d4fa243
此差异已折叠。
点击以展开。
paddle/function/CrossMapNormalOpGpu.cu
浏览文件 @
1d4fa243
此差异已折叠。
点击以展开。
paddle/function/DepthwiseConvOpGpu.cu
浏览文件 @
1d4fa243
此差异已折叠。
点击以展开。
paddle/function/Im2ColOpGpu.cu
浏览文件 @
1d4fa243
此差异已折叠。
点击以展开。
paddle/function/MulOpGpu.cu
浏览文件 @
1d4fa243
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "hl_base.h"
#include "MulOp.h"
#include "MulOp.h"
#include "hl_base.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/SparseMatrix.h"
#include "paddle/math/SparseMatrix.h"
...
...
paddle/function/PadOpGpu.cu
浏览文件 @
1d4fa243
...
@@ -12,15 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,15 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "hl_base.h"
#include "PadOp.h"
#include "PadOp.h"
#include "hl_base.h"
namespace
paddle
{
namespace
paddle
{
__global__
void
KePad
(
real
*
outputs
,
const
real
*
inputs
,
__global__
void
KePad
(
real
*
outputs
,
int
inC
,
int
inH
,
int
inW
,
const
real
*
inputs
,
int
padc
,
int
padh
,
int
padw
,
int
inC
,
int
outC
,
int
outH
,
int
outW
,
int
nthreads
)
{
int
inH
,
int
inW
,
int
padc
,
int
padh
,
int
padw
,
int
outC
,
int
outH
,
int
outW
,
int
nthreads
)
{
const
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
const
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
nthreads
)
{
if
(
idx
<
nthreads
)
{
const
int
w
=
idx
%
inW
;
const
int
w
=
idx
%
inW
;
...
@@ -50,16 +58,33 @@ void Pad<DEVICE_TYPE_GPU>(real* outputs,
...
@@ -50,16 +58,33 @@ void Pad<DEVICE_TYPE_GPU>(real* outputs,
int
outC
=
inC
+
cstart
+
cend
;
int
outC
=
inC
+
cstart
+
cend
;
int
outH
=
inH
+
hstart
+
hend
;
int
outH
=
inH
+
hstart
+
hend
;
int
outW
=
inW
+
wstart
+
wend
;
int
outW
=
inW
+
wstart
+
wend
;
KePad
<<<
gridSize
,
blockSize
,
0
,
STREAM_DEFAULT
>>>
KePad
<<<
gridSize
,
blockSize
,
0
,
STREAM_DEFAULT
>>>
(
outputs
,
(
outputs
,
inputs
,
inC
,
inH
,
inW
,
cstart
,
hstart
,
wstart
,
inputs
,
outC
,
outH
,
outW
,
nth
);
inC
,
inH
,
inW
,
cstart
,
hstart
,
wstart
,
outC
,
outH
,
outW
,
nth
);
CHECK_SYNC
(
"Pad"
);
CHECK_SYNC
(
"Pad"
);
}
}
__global__
void
KePadDiff
(
real
*
inGrad
,
const
real
*
outGrad
,
__global__
void
KePadDiff
(
real
*
inGrad
,
int
inC
,
int
inH
,
int
inW
,
const
real
*
outGrad
,
int
padc
,
int
padh
,
int
padw
,
int
inC
,
int
outC
,
int
outH
,
int
outW
,
int
nthreads
)
{
int
inH
,
int
inW
,
int
padc
,
int
padh
,
int
padw
,
int
outC
,
int
outH
,
int
outW
,
int
nthreads
)
{
const
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
const
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
nthreads
)
{
if
(
idx
<
nthreads
)
{
const
int
w
=
idx
%
inW
;
const
int
w
=
idx
%
inW
;
...
@@ -89,9 +114,18 @@ void PadGrad<DEVICE_TYPE_GPU>(real* inGrad,
...
@@ -89,9 +114,18 @@ void PadGrad<DEVICE_TYPE_GPU>(real* inGrad,
int
outC
=
inC
+
cstart
+
cend
;
int
outC
=
inC
+
cstart
+
cend
;
int
outH
=
inH
+
hstart
+
hend
;
int
outH
=
inH
+
hstart
+
hend
;
int
outW
=
inW
+
wstart
+
wend
;
int
outW
=
inW
+
wstart
+
wend
;
KePadDiff
<<<
gridSize
,
blockSize
,
0
,
STREAM_DEFAULT
>>>
KePadDiff
<<<
gridSize
,
blockSize
,
0
,
STREAM_DEFAULT
>>>
(
inGrad
,
(
inGrad
,
outGrad
,
inC
,
inH
,
inW
,
cstart
,
hstart
,
wstart
,
outGrad
,
outC
,
outH
,
outW
,
nth
);
inC
,
inH
,
inW
,
cstart
,
hstart
,
wstart
,
outC
,
outH
,
outW
,
nth
);
CHECK_SYNC
(
"PadGrad"
);
CHECK_SYNC
(
"PadGrad"
);
}
}
...
...
paddle/function/RowConvOpGpu.cu
浏览文件 @
1d4fa243
此差异已折叠。
点击以展开。
paddle/gserver/layers/GruCompute.cu
浏览文件 @
1d4fa243
此差异已折叠。
点击以展开。
paddle/gserver/layers/LstmCompute.cu
浏览文件 @
1d4fa243
此差异已折叠。
点击以展开。
paddle/math/BaseMatrix.cu
浏览文件 @
1d4fa243
此差异已折叠。
点击以展开。
paddle/math/TrainingAlgorithmOp.cu
浏览文件 @
1d4fa243
此差异已折叠。
点击以展开。
paddle/math/tests/test_Tensor.cu
浏览文件 @
1d4fa243
此差异已折叠。
点击以展开。
paddle/math/tests/test_lazyAssign.cu
浏览文件 @
1d4fa243
此差异已折叠。
点击以展开。
paddle/operators/softmax_op.cu
浏览文件 @
1d4fa243
此差异已折叠。
点击以展开。
paddle/trainer/tests/pydata_provider_wrapper_dir/test_pydata_provider_wrapper.proto
浏览文件 @
1d4fa243
无法预览此类型文件
proto/DataConfig.proto
浏览文件 @
1d4fa243
此差异已折叠。
点击以展开。
proto/DataFormat.proto
浏览文件 @
1d4fa243
此差异已折叠。
点击以展开。
proto/ModelConfig.proto
浏览文件 @
1d4fa243
此差异已折叠。
点击以展开。
proto/OptimizerConfig.proto
浏览文件 @
1d4fa243
此差异已折叠。
点击以展开。
proto/ParameterConfig.proto
浏览文件 @
1d4fa243
此差异已折叠。
点击以展开。
proto/ParameterServerConfig.proto
浏览文件 @
1d4fa243
此差异已折叠。
点击以展开。
proto/ParameterService.proto
浏览文件 @
1d4fa243
此差异已折叠。
点击以展开。
proto/TrainerConfig.proto
浏览文件 @
1d4fa243
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录