PaddlePaddle / Paddle, commit b818429a (unverified)

Authored by wangchaochaohu on Nov 28, 2020; committed via GitHub on Nov 28, 2020.

optimize cumsum OP (#29193)

Parent: 27b42183
Showing 1 changed file with 196 additions and 253 deletions.

paddle/fluid/operators/cumsum_op.cu  (+196, -253)
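For reference, the semantics these kernels must preserve: cumsum computes running sums along one axis; with exclusive each output element excludes its own input (the result is shifted by one and starts at zero), and with reverse the accumulation runs from the end of the axis. A minimal 1-D host-side sketch of those semantics using Thrust (illustrative only, not part of this commit):

#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/reverse.h>
#include <thrust/scan.h>

// Illustrative 1-D reference for the cumsum semantics (not commit code).
void CumsumReference1D(const float* in, float* out, int n, bool exclusive,
                       bool reverse) {
  thrust::device_vector<float> buf(in, in + n);  // copy input to device
  if (reverse) thrust::reverse(buf.begin(), buf.end());
  if (exclusive) {
    // exclusive: out[i] = in[0] + ... + in[i-1], with out[0] = 0
    thrust::exclusive_scan(buf.begin(), buf.end(), buf.begin());
  } else {
    // inclusive: out[i] = in[0] + ... + in[i]
    thrust::inclusive_scan(buf.begin(), buf.end(), buf.begin());
  }
  if (reverse) thrust::reverse(buf.begin(), buf.end());
  thrust::copy(buf.begin(), buf.end(), out);  // copy result back to host
}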
@@ -14,8 +14,10 @@ limitations under the License. */
 #include <thrust/device_ptr.h>
 #include <thrust/device_vector.h>
 #include <thrust/gather.h>
 #include <thrust/reverse.h>
 #include <thrust/scan.h>
+#include "cub/cub.cuh"
 #include "paddle/fluid/operators/cum_op.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"

@@ -25,223 +27,157 @@ using LoDTensor = paddle::framework::LoDTensor;
 namespace paddle {
 namespace operators {

-template <typename T>
-__global__ void OuterScan(const T* in, T* out, int inner_dim_size,
-                          int outer_dim_size, int scan_dim_size,
-                          bool exclusive, bool reverse) {
-  int id = blockIdx.y * blockDim.x + threadIdx.x;
-
-  for (int outer_index = blockIdx.x; outer_index < outer_dim_size;
-       outer_index += gridDim.x) {
-    for (int inner_index = blockIdx.y * blockDim.x + threadIdx.x;
-         inner_index < inner_dim_size; inner_index += gridDim.y * blockDim.x) {
-      int scan_index_init = 0;
-      int forward_direction = 1;
-      int src_index =
-          outer_index * scan_dim_size * inner_dim_size + inner_index;
-      int dst_index =
-          outer_index * scan_dim_size * inner_dim_size + inner_index;
-      if (reverse) {
-        src_index = src_index + (scan_dim_size - 1) * inner_dim_size;
-        dst_index = dst_index + (scan_dim_size - 1) * inner_dim_size;
-        forward_direction = -1;
-      }
-      if (exclusive) {
-        scan_index_init = 1;
-        out[dst_index] = 0;
-        dst_index = dst_index + (forward_direction * inner_dim_size);
-      }
-      T acc = 0;
-      for (int scan_index = scan_index_init; scan_index < scan_dim_size;
-           ++scan_index) {
-        acc = in[src_index] + acc;
-        out[dst_index] = acc;
-        src_index += (forward_direction * inner_dim_size);
-        dst_index += (forward_direction * inner_dim_size);
-      }
-    }
-  }
-}

+template <typename T, int BLOCK_SIZE>
+__device__ void BlockReverse(const T* idata, T* odata, int src_base,
+                             int dst_base, int valid_item) {
+  __shared__ T sh_mem[BLOCK_SIZE];
+  int tx = threadIdx.x;
+
+  int offset = tx;
+  int in_index = src_base + offset;
+  if (offset >= valid_item) {
+    sh_mem[offset] = 0;
+  } else {
+    int sh_mem_index = BLOCK_SIZE - offset - 1;
+    T data = idata[in_index];
+    sh_mem[sh_mem_index] = data;
+  }
+
+  __syncthreads();
+  int out_index = dst_base - offset;
+  if (offset < valid_item) {
+    int sh_mem_index = BLOCK_SIZE - offset - 1;
+    odata[out_index] = sh_mem[sh_mem_index];
+  }
+}
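The removed OuterScan above is essentially one sequential scan per (outer, inner) pair: each thread walks the whole scan axis with stride inner_dim_size, so runtime grows linearly with the scan dimension and there is no intra-row parallelism. A CPU model of what a single OuterScan thread computes in the inclusive, forward case (a sketch with illustrative names, not commit code):

// CPU model of one OuterScan thread: a sequential scan along the axis for a
// fixed (outer_index, inner_index) pair; the stride between consecutive
// elements is inner_dim_size because the scan axis is not innermost.
void ScanOneSlice(const float* in, float* out, int outer_index,
                  int inner_index, int inner_dim_size, int scan_dim_size) {
  int idx = outer_index * scan_dim_size * inner_dim_size + inner_index;
  float acc = 0;
  for (int s = 0; s < scan_dim_size; ++s) {
    acc += in[idx];
    out[idx] = acc;
    idx += inner_dim_size;  // step to the next element along the scan axis
  }
}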
-// inclusive scan
-template <typename T, int num_threads_x, int num_threads_y>
-__global__ void InnerMostDimInclusiveScan(const T* in, T* out,
-                                          int inner_dim_size,
-                                          int outer_dim_size,
-                                          int scan_dim_size, bool reverse) {
-  __shared__ T share_data[num_threads_y][num_threads_x * 2];
-  T* share_row = share_data[threadIdx.y];
-  int forward_direction = 1;
-  if (reverse) forward_direction = -1;
-
-  for (int block_row = blockIdx.x * blockDim.y; block_row < outer_dim_size;
-       block_row += blockDim.y * gridDim.x) {
-    int row = block_row + threadIdx.y;
-    T acc = 0;
-    const T* row_src = in + row * scan_dim_size;
-    T* row_dst = out + row * scan_dim_size;
-    int block_col = 0;
-    bool loop_condition = (block_col < scan_dim_size);
-    if (reverse) {
-      loop_condition = (block_col >= 0);
-      block_col = scan_dim_size - 1;
-    }
-    while (loop_condition) {
-      // Load data into share memory(two value per thread)
-      int col1 = block_col + threadIdx.x * forward_direction;
-      int col2 = block_col + (num_threads_x + threadIdx.x) * forward_direction;
-      if (row < outer_dim_size) {
-        if (col1 < scan_dim_size && col1 >= 0) {
-          share_row[threadIdx.x] = row_src[col1];
-        } else {
-          share_row[threadIdx.x] = 0;
-        }
-        if (col2 < scan_dim_size && col2 >= 0) {
-          share_row[num_threads_x + threadIdx.x] = row_src[col2];
-        } else {
-          share_row[num_threads_x + threadIdx.x] = 0;
-        }
-        // Add the previous block acc to the result
-        if (threadIdx.x == 0) {
-          share_row[0] = share_row[0] + acc;
-        }
-      }
-      __syncthreads();
-      // Up-Sweep
-      for (unsigned s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) {
-        if (row < outer_dim_size && threadIdx.x < s) {
-          unsigned offset = (2 * threadIdx.x + 1) * d - 1;
-          share_row[offset + d] = share_row[offset] + share_row[offset + d];
-        }
-        __syncthreads();
-      }
-      // Down-Sweep
-      for (unsigned s = 2, d = blockDim.x / 2; d >= 1; s <<= 1, d >>= 1) {
-        if (row < outer_dim_size && threadIdx.x < s - 1) {
-          unsigned offset = 2 * (threadIdx.x + 1) * d - 1;
-          share_row[offset + d] = share_row[offset] + share_row[offset + d];
-        }
-        __syncthreads();
-      }
-      // Write to the output
-      if (row < outer_dim_size) {
-        if (col1 < scan_dim_size && col1 >= 0)
-          row_dst[col1] = share_row[threadIdx.x];
-        if (col2 < scan_dim_size && col2 >= 0)
-          row_dst[col2] = share_row[num_threads_x + threadIdx.x];
-      }
-      acc = share_row[2 * num_threads_x - 1];
-      __syncthreads();
-      block_col += 2 * num_threads_x * forward_direction;
-      if (reverse)
-        loop_condition = (block_col >= 0);
-      else
-        loop_condition = (block_col < scan_dim_size);
-    }
-  }
-}

+template <typename T>
+__global__ void MatrixRowReverse(const T* matrix_data, T* reverse_data,
+                                 int reverse_size, int outer_size,
+                                 int inner_size) {
+  int bx = blockIdx.x;
+  int by = blockIdx.y;
+  int item_per_block = 1024;
+
+  for (int block_offset = 0; block_offset < reverse_size;
+       block_offset += item_per_block) {
+    int valid_item = (reverse_size - block_offset > item_per_block)
+                         ? item_per_block
+                         : reverse_size - block_offset;
+    int src_offset =
+        bx * reverse_size + block_offset + by * (inner_size * reverse_size);
+    int dst_offset = bx * reverse_size + by * (inner_size * reverse_size) +
+                     reverse_size - 1 - block_offset;
+    if (reverse_size < item_per_block) {
+      valid_item = reverse_size;
+    }
+
+    BlockReverse<T, 1024>(matrix_data, reverse_data, src_offset, dst_offset,
+                          valid_item);
+  }
+}
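The Up-Sweep/Down-Sweep pair inside the removed InnerMostDimInclusiveScan is a work-efficient shared-memory tree scan over the 2 * num_threads_x elements staged per row: the up-sweep folds pairwise sums up a binary tree, and the down-sweep pushes partial sums back down until every slot holds its inclusive prefix. A host-side model of the two sweeps (a sketch for intuition, assuming n is a power of two; not commit code):

// Host model of the kernel's two tree-sweep phases over n = 2 * num_threads_x
// elements. After both phases, data[] holds the inclusive prefix sums.
void TreeScanModel(float* data, int n) {
  // Up-sweep: accumulate sums up a binary tree (data[offset + d] absorbs
  // data[offset]), exactly the kernel's first loop.
  for (int s = n / 2, d = 1; s >= 1; s >>= 1, d <<= 1) {
    for (int t = 0; t < s; ++t) {
      int offset = (2 * t + 1) * d - 1;
      data[offset + d] += data[offset];
    }
  }
  // Down-sweep: propagate partial sums back down so every slot becomes an
  // inclusive prefix, exactly the kernel's second loop.
  for (int s = 2, d = n / 4; d >= 1; s <<= 1, d >>= 1) {
    for (int t = 0; t < s - 1; ++t) {
      int offset = 2 * (t + 1) * d - 1;
      data[offset + d] += data[offset];
    }
  }
}

For n = 8 and all-ones input, the up-sweep leaves [1,2,1,4,1,2,1,8] and the down-sweep completes it to [1,2,3,4,5,6,7,8].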
-// exclusive block scan and store block sum for large scan
-template <typename T>
-__global__ void InnerMostDimExclusiveScan(const T* in, T* out, T* sum_data,
-                                          int inner_dim_size,
-                                          int outer_dim_size,
-                                          int scan_dim_size, int two_power,
-                                          bool reverse) {
-  // https://stackoverflow.com/questions/27570552/templated-cuda-kernel-with-dynamic-shared-memory
-  extern __shared__ __align__(sizeof(T)) unsigned char raw_tmp[];
-  T* share_tmp = reinterpret_cast<T*>(raw_tmp);
-  int thread_id = threadIdx.x;
-  int block_id = blockIdx.x;
-  int block_scan_size = blockDim.x * 2;
-  int remain = scan_dim_size % (2 * blockDim.x);
-  if (block_id == gridDim.x - 1 && remain != 0) block_scan_size = remain;
-  int col1 = thread_id;
-  int col2 = thread_id + (block_scan_size) / 2;
-  int index1 = blockIdx.y * (scan_dim_size) + block_id * blockDim.x * 2 + col1;
-  int index2 = blockIdx.y * (scan_dim_size) + block_id * blockDim.x * 2 + col2;
-  if (reverse) {
-    index1 = blockIdx.y * (scan_dim_size) + scan_dim_size - 1 -
-             (block_id * blockDim.x * 2 + col1);
-    index2 = blockIdx.y * (scan_dim_size) + scan_dim_size - 1 -
-             (block_id * blockDim.x * 2 + col2);
-  }
-  int sum_index = blockIdx.y * gridDim.x + block_id;
-  if (thread_id < block_scan_size) {
-    share_tmp[col1 + (col1 >> 5)] = in[index1];
-    share_tmp[col2 + (col2 >> 5)] = in[index2];
-  } else {
-    share_tmp[col1 + (col1 >> 5)] = 0;
-    share_tmp[col2 + (col2 >> 5)] = 0;
-  }
-
-  // Up-Sweep
-  int offset = 1;
-  for (int d = (two_power / 2); d > 0; d >>= 1) {
-    __syncthreads();
-    if (thread_id < d) {
-      int tmp_index1 = offset * (2 * thread_id + 1) - 1;
-      int tmp_index2 = offset * (2 * thread_id + 2) - 1;
-      tmp_index1 = tmp_index1 + (tmp_index1 >> 5);
-      tmp_index2 = tmp_index2 + (tmp_index2 >> 5);
-      share_tmp[tmp_index2] += share_tmp[tmp_index1];
-    }
-    offset *= 2;
-  }
-  __syncthreads();
-
-  if (thread_id == 0) {
-    int tmp_index = (two_power - 1) + ((two_power - 1) >> 5);
-    sum_data[sum_index] = share_tmp[tmp_index];
-    share_tmp[tmp_index] = 0;
-  }
-
-  // Down Sweep
-  for (int d = 1; d < two_power; d *= 2) {
-    offset >>= 1;
-    __syncthreads();
-    if (thread_id < d) {
-      int tmp_index1 = offset * (2 * thread_id + 1) - 1;
-      int tmp_index2 = offset * (2 * thread_id + 2) - 1;
-      tmp_index1 = tmp_index1 + (tmp_index1 >> 5);
-      tmp_index2 = tmp_index2 + (tmp_index2 >> 5);
-      T tmp = share_tmp[tmp_index1];
-      share_tmp[tmp_index1] = share_tmp[tmp_index2];
-      share_tmp[tmp_index2] += tmp;
-    }
-  }
-  __syncthreads();
-
-  if (col1 < block_scan_size) out[index1] = share_tmp[col1 + (col1 >> 5)];
-  if (col2 < block_scan_size) out[index2] = share_tmp[col2 + (col2 >> 5)];
-}

+template <typename T>
+struct BlockPrefixCallbackOp {
+  // Running prefix
+  T running_total;
+  // Constructor
+  __device__ BlockPrefixCallbackOp(T running_total)
+      : running_total(running_total) {}
+  // Callback operator to be entered by the first warp of threads in the block.
+  // Thread-0 is responsible for returning a value for seeding the block-wide
+  // scan.
+  __device__ T operator()(T block_aggregate) {
+    T old_prefix = running_total;
+    running_total = old_prefix + block_aggregate;
+    return old_prefix;
+  }
+};
+
+// No bank-conflict transpose
+// Same as transposeCoalesced except the first tile dimension is padded
+// to avoid shared memory bank conflicts.
+template <typename T, int TILE_DIM, int BLOCK_ROWS>
+__global__ void MatrixTranspose(T* odata, const T* idata, size_t height,
+                                size_t width) {
+  __shared__ T tile[TILE_DIM][TILE_DIM + 1];
+
+  int x = blockIdx.x * TILE_DIM + threadIdx.x;
+  int y = blockIdx.y * TILE_DIM + threadIdx.y;
+  for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
+    if (x < width && (y + j) < height) {
+      tile[threadIdx.y + j][threadIdx.x] = idata[(y + j) * width + x];
+    } else {
+      tile[threadIdx.y + j][threadIdx.x] = 0;
+    }
+  }
+
+  __syncthreads();
+
+  x = blockIdx.y * TILE_DIM + threadIdx.x;  // transpose block offset
+  y = blockIdx.x * TILE_DIM + threadIdx.y;
+
+  for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
+    if (x < height && (y + j) < width) {
+      odata[(y + j) * height + x] = tile[threadIdx.x][threadIdx.y + j];
+    }
+  }
+}
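The TILE_DIM + 1 padding in the added MatrixTranspose and the i + (i >> 5) indexing in the removed exclusive kernel serve the same purpose: NVIDIA shared memory is divided into 32 banks, so inserting one pad slot per 32 elements keeps strided accesses from serializing on a single bank. The flat-array variant can be captured in one illustrative helper (not commit code):

// Illustrative: map a logical index into a padded shared-memory array,
// adding one pad slot per 32 elements so tree-scan strides avoid bank
// conflicts. The padded array must be sized n + (n >> 5).
__device__ __forceinline__ int PaddedIndex(int i) { return i + (i >> 5); }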
+template <typename T, int BLOCK_THREADS, int ITEMS_PER_THREAD>
+__global__ void BlockScanKernel(T* d_out, const T* d_in, int inner_size,
+                                int outer_size, int scan_size,
+                                bool exclusive) {
+  // Specialize BlockLoad, BlockStore, and BlockRadixSort collective types
+  typedef cub::BlockLoad<T, BLOCK_THREADS, ITEMS_PER_THREAD,
+                         cub::BLOCK_LOAD_TRANSPOSE>
+      BlockLoadT;
+  typedef cub::BlockStore<T, BLOCK_THREADS, ITEMS_PER_THREAD,
+                          cub::BLOCK_STORE_TRANSPOSE>
+      BlockStoreT;
+  typedef cub::BlockScan<T, BLOCK_THREADS> BlockScanT;
+  // Allocate type-safe, repurposable shared memory for collectives
+  __shared__ union {
+    typename BlockLoadT::TempStorage load;
+    typename BlockStoreT::TempStorage store;
+    typename BlockScanT::TempStorage scan;
+  } temp_storage;
+
+  int bx = blockIdx.x;
+  int by = blockIdx.y;
+
+  BlockPrefixCallbackOp<T> prefix_op(0);
+  T block_aggregate = static_cast<T>(0);
+
+  // Obtain this block's segment of consecutive keys (blocked across threads)
+  int item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;
+  for (int block_offset = 0; block_offset < scan_size;
+       block_offset += BLOCK_THREADS * ITEMS_PER_THREAD) {
+    int valid_item = (scan_size - block_offset > item_per_block)
+                         ? item_per_block
+                         : (scan_size - block_offset);
+    if (scan_size < item_per_block) {
+      valid_item = scan_size;
+    }
+
+    int offset = bx * scan_size + block_offset + by * (inner_size * scan_size);
+
+    T thread_keys[ITEMS_PER_THREAD];
+    BlockLoadT(temp_storage.load)
+        .Load(d_in + offset, thread_keys, valid_item, 0);
+
+    __syncthreads();
+    if (exclusive) {
+      T init_value = static_cast<T>(0);
+      BlockScanT(temp_storage.scan)
+          .ExclusiveScan(thread_keys, thread_keys, cub::Sum(), prefix_op);
+    } else {
+      BlockScanT(temp_storage.scan)
+          .InclusiveScan(thread_keys, thread_keys, cub::Sum(), prefix_op);
+    }
+    __syncthreads();
+
+    BlockStoreT(temp_storage.store)
+        .Store(d_out + offset, thread_keys, valid_item);
+  }
+}

-// for large scan_dim_size array we need to add for correct result
-template <typename T>
-__global__ void AddBlockScan(T* result, T* sum, int size, int scan_dim_size,
-                             int sum_size, bool reverse) {
-  int idx = threadIdx.x + blockDim.x * (blockIdx.x + blockIdx.y * gridDim.x);
-  int block_id_start = blockIdx.y * sum_size;
-  int block_id_end = blockIdx.x + blockIdx.y * sum_size;
-  int block_id = blockIdx.x;
-  int thread_id = threadIdx.x;
-
-  int col = block_id * blockDim.x + thread_id + size;
-  int index = blockIdx.y * (scan_dim_size) + col;
-  if (reverse) {
-    index = blockIdx.y * (scan_dim_size) + scan_dim_size - 1 - col;
-  }
-
-  if (col >= scan_dim_size || col < 0) return;
-  for (int i = block_id_start; i <= block_id_end; i++) {
-    result[index] += sum[i];
-  }
-}
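BlockScanKernel handles scan rows longer than one tile (BLOCK_THREADS * ITEMS_PER_THREAD elements) through cub's prefix-callback mechanism: the BlockPrefixCallbackOp instance defined above carries the running total from one tile into the next, so a single block can scan an arbitrarily long contiguous row. A minimal standalone sketch of the same pattern (assuming the cub headers and the BlockPrefixCallbackOp above; names are illustrative, not commit code):

#include "cub/cub.cuh"

// Minimal sketch: one block scans `n` floats tile by tile; prefix_op carries
// the running total across tiles exactly as in BlockScanKernel above.
template <int BLOCK_THREADS, int ITEMS_PER_THREAD>
__global__ void TiledInclusiveScan(const float* d_in, float* d_out, int n) {
  typedef cub::BlockLoad<float, BLOCK_THREADS, ITEMS_PER_THREAD,
                         cub::BLOCK_LOAD_TRANSPOSE>
      BlockLoadT;
  typedef cub::BlockStore<float, BLOCK_THREADS, ITEMS_PER_THREAD,
                          cub::BLOCK_STORE_TRANSPOSE>
      BlockStoreT;
  typedef cub::BlockScan<float, BLOCK_THREADS> BlockScanT;
  __shared__ union {
    typename BlockLoadT::TempStorage load;
    typename BlockStoreT::TempStorage store;
    typename BlockScanT::TempStorage scan;
  } temp_storage;

  BlockPrefixCallbackOp<float> prefix_op(0);  // running total starts at zero
  const int tile = BLOCK_THREADS * ITEMS_PER_THREAD;
  for (int offset = 0; offset < n; offset += tile) {
    int valid = min(tile, n - offset);
    float items[ITEMS_PER_THREAD];
    // Out-of-range slots are filled with 0 so they do not disturb the sum.
    BlockLoadT(temp_storage.load).Load(d_in + offset, items, valid, 0.0f);
    __syncthreads();
    BlockScanT(temp_storage.scan)
        .InclusiveScan(items, items, cub::Sum(), prefix_op);
    __syncthreads();
    BlockStoreT(temp_storage.store).Store(d_out + offset, items, valid);
  }
}

Launched as TiledInclusiveScan<128, 4><<<1, 128>>>(d_in, d_out, n), this scans one row; BlockScanKernel simply runs one such block per (outer, inner) row in its 2-D grid.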
@@ -298,72 +234,79 @@ class CumCUDAKernel : public framework::OpKernel<T> {
       return;
     }
-    const int& scan_dim_size = out_dims[axis];
-    bool optimize_condition = (axis == (out_dims.size() - 1)) ? true : false;
-    int outer_dim_size = 1;
-    int inner_dim_size = 1;
-    // treat all dim index < axis as outer_dim_size
-    for (size_t i = 0; i < axis; i++) {
-      outer_dim_size *= out_dims[i];
-    }
-    // treat all dim index > axis as innner_dim_size
-    for (size_t i = axis + 1; i < out_dims.size(); i++) {
-      inner_dim_size *= out_dims[i];
-    }
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    if (optimize_condition) {
-      auto nextPowerOfTwo = [](int x) -> int {
-        int ret = 1;
-        while (ret < x) ret = ret * 2;
-        return ret;
-      };
-      if (exclusive) {
-        int element_per_block = nextPowerOfTwo(scan_dim_size) / 2;
-        if (element_per_block > 512 || element_per_block < 32) {
-          element_per_block = 64;
-        }
-        int two_power = element_per_block * 2;
-        dim3 block(element_per_block);
-        dim3 grid(((scan_dim_size + 1) / 2 + block.x - 1) / block.x,
-                  outer_dim_size);
-        int offset_size = (element_per_block * 2) >> 5;
-        int share_mem_size = (element_per_block * 2 + offset_size) * sizeof(T);
-        Tensor scan_sum;
-        paddle::framework::DDim dims{
-            ((scan_dim_size + 1) / 2 + block.x - 1) / block.x, outer_dim_size};
-        scan_sum.Resize(dims);
-        T* sum_data = scan_sum.mutable_data<T>(context.GetPlace());
-        InnerMostDimExclusiveScan<
-            T><<<grid, block, share_mem_size, dev_ctx.stream()>>>(
-            in_data, out_data, sum_data, inner_dim_size, outer_dim_size,
-            scan_dim_size, two_power, reverse);
-        // for large scan array we need to do add for correct result
-        int element_size = element_per_block * 2;
-        if (scan_dim_size > element_size) {
-          dim3 sum_block(element_per_block * 2);
-          dim3 sum_grid((scan_dim_size - element_size + block.x - 1) / block.x,
-                        outer_dim_size);
-          int sum_size = ((scan_dim_size + 1) / 2 + block.x - 1) / block.x;
-          AddBlockScan<T><<<sum_grid, sum_block, 0, dev_ctx.stream()>>>(
-              out_data, sum_data, element_size, scan_dim_size, sum_size,
-              reverse);
-        }
-      } else {
-        dim3 block(32, 16);
-        dim3 grid((outer_dim_size + block.y - 1) / block.y);
-        InnerMostDimInclusiveScan<T, 32,
-                                  16><<<grid, block, 0, dev_ctx.stream()>>>(
-            in_data, out_data, inner_dim_size, outer_dim_size, scan_dim_size,
-            reverse);
-      }
-    } else {
-      dim3 block(std::min(512, inner_dim_size));
-      dim3 grid(outer_dim_size, (inner_dim_size + block.x - 1) / block.x);
-      OuterScan<T><<<grid, block, 0, dev_ctx.stream()>>>(
-          in_data, out_data, inner_dim_size, outer_dim_size, scan_dim_size,
-          exclusive, reverse);
-    }
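For concreteness, the launch arithmetic of the removed exclusive path above, worked for a hypothetical scan_dim_size = 1000 (not from the commit):

// Worked example (illustrative): scan_dim_size = 1000
//   nextPowerOfTwo(1000) -> 1024
//   element_per_block    = 1024 / 2 = 512   (inside the [32, 512] clamp)
//   two_power            = 512 * 2  = 1024
//   offset_size          = 1024 >> 5 = 32   (one pad slot per 32 elements)
//   share_mem_size       = (1024 + 32) * sizeof(T) bytes of dynamic smem
//   grid.x               = ((1000 + 1) / 2 + 511) / 512 = 1 block per row
// Since scan_dim_size (1000) is not greater than element_size (1024), the
// AddBlockScan fix-up pass is skipped.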
+    size_t height = 1;
+    size_t width = 1;
+    for (size_t i = 0; i <= axis; i++) {
+      height *= out_dims[i];
+    }
+    for (size_t i = axis + 1; i < out_dims.size(); i++) {
+      width *= out_dims[i];
+    }
+    int scan_size = out_dims[axis];
+    bool transpose = (axis != out_dims.size() - 1);
+
+    int tile_size = 32;
+    dim3 blocks(32, 8);
+    dim3 transpose_grids((width + tile_size - 1) / tile_size,
+                         (height + tile_size - 1) / tile_size);
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    Tensor tmp;
+    tmp.Resize(out_dims);
+    auto* tmp_data = tmp.mutable_data<T>(context.GetPlace());
+    T* next_in_data = out_data;
+    T* next_out_data = tmp_data;
+    if (transpose) {
+      MatrixTranspose<T, 32,
+                      8><<<transpose_grids, blocks, 0, dev_ctx.stream()>>>(
+          out_data, in_data, height, width);
+      next_in_data = out_data;
+      next_out_data = tmp_data;
+    }
+    auto swap_ptr = [](T*& ptr1, T*& ptr2) {
+      T* tmp = ptr2;
+      ptr2 = ptr1;
+      ptr1 = tmp;
+    };
+    int outer_size = height / scan_size;
+    int inner_size = width;
+    // Consider the size of shared memory, here block size is 128
+    dim3 scan_grid(outer_size, inner_size);
+    dim3 reverse_grid = scan_grid;
+    if (reverse) {
+      if (transpose) {
+        reverse_grid.x = scan_grid.y;
+        reverse_grid.y = scan_grid.x;
+        MatrixRowReverse<T><<<reverse_grid, 1024, 0, dev_ctx.stream()>>>(
+            next_in_data, next_out_data, scan_size, outer_size, inner_size);
+        if (!transpose) next_in_data = tmp_data;
+        swap_ptr(next_in_data, next_out_data);
+      } else {
+        MatrixRowReverse<T><<<reverse_grid, 1024, 0, dev_ctx.stream()>>>(
+            in_data, out_data, scan_size, outer_size, inner_size);
+      }
+    }
+    if (!transpose && !reverse) {
+      BlockScanKernel<T, 128, 4><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
+          out_data, in_data, outer_size, inner_size, scan_size, exclusive);
+    } else {
+      BlockScanKernel<T, 128, 4><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
+          next_out_data, next_in_data, outer_size, inner_size, scan_size,
+          exclusive);
+    }
+    swap_ptr(next_in_data, next_out_data);
+    if (reverse) {
+      MatrixRowReverse<T><<<reverse_grid, 1024, 0, dev_ctx.stream()>>>(
+          next_in_data, next_out_data, scan_size, outer_size, inner_size);
+      swap_ptr(next_in_data, next_out_data);
+    }
+    if (transpose) {
+      transpose_grids.x = (height + tile_size - 1) / tile_size;
+      transpose_grids.y = (width + tile_size - 1) / tile_size;
+      MatrixTranspose<T, 32,
+                      8><<<transpose_grids, blocks, 0, dev_ctx.stream()>>>(
+          next_out_data, next_in_data, width, height);
+    }
   }
 };
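To see how the new dispatch composes, it helps to work through a hypothetical shape, say a [4, 5, 6] input scanned along axis = 1 (illustrative, not from the commit):

// Worked example (illustrative): x.shape = [4, 5, 6], axis = 1
//   height    = 4 * 5 = 20   (product of dims up to and including axis)
//   width     = 6            (product of dims after axis)
//   scan_size = 5, transpose = true (axis is not innermost)
// Pipeline:
//   1. MatrixTranspose: view the data as a 20 x 6 matrix and transpose it to
//      6 x 20, making each length-5 scan segment contiguous in memory.
//   2. (reverse only) MatrixRowReverse flips each length-5 segment.
//   3. BlockScanKernel with scan_grid(outer_size = 20 / 5 = 4,
//      inner_size = 6) runs one cub block scan per contiguous segment.
//   4. (reverse only) MatrixRowReverse flips the result back.
//   5. MatrixTranspose back to the original 20 x 6 layout.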