Unverified commit b818429a
Authored on Nov 28, 2020 by wangchaochaohu; committed via GitHub on Nov 28, 2020

optimize cumsum OP (#29193)

Parent: 27b42183

Showing 1 changed file with 196 additions and 253 deletions

paddle/fluid/operators/cumsum_op.cu  (+196, −253)
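For orientation, cumsum computes a running total along one axis; the `exclusive` option shifts the result so each element excludes itself, and `reverse` accumulates from the end of the axis. Every kernel in this diff is parameterized by those options and by an [outer, scan, inner] factorization of the tensor shape. The reference sketch below only illustrates that semantics on such a layout; CumsumReference and its arguments are invented names for this note and are not code from the Paddle repository.

#include <cstddef>
#include <vector>

// Reference semantics for cumsum along the middle ("scan") dimension of an
// [outer, scan, inner] view. exclusive=true writes the prefix that excludes
// the current element; reverse=true accumulates from the end of the scan dim.
void CumsumReference(const std::vector<float>& in, std::vector<float>* out,
                     int outer, int scan, int inner, bool exclusive,
                     bool reverse) {
  for (int o = 0; o < outer; ++o) {
    for (int i = 0; i < inner; ++i) {
      float acc = 0.0f;
      for (int s = 0; s < scan; ++s) {
        int idx = reverse ? (scan - 1 - s) : s;
        size_t offset = (static_cast<size_t>(o) * scan + idx) * inner + i;
        if (exclusive) {
          (*out)[offset] = acc;  // running total before the current element
          acc += in[offset];
        } else {
          acc += in[offset];     // running total including the current element
          (*out)[offset] = acc;
        }
      }
    }
  }
}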
@@ -14,8 +14,10 @@ limitations under the License. */
 #include <thrust/device_ptr.h>
 #include <thrust/device_vector.h>
 #include <thrust/gather.h>
 #include <thrust/reverse.h>
 #include <thrust/scan.h>
+#include "cub/cub.cuh"
 #include "paddle/fluid/operators/cum_op.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
@@ -25,223 +27,157 @@ using LoDTensor = paddle::framework::LoDTensor;
 namespace paddle {
 namespace operators {
 
-template <typename T>
-__global__ void OuterScan(const T* in, T* out, int inner_dim_size,
-                          int outer_dim_size, int scan_dim_size,
-                          bool exclusive, bool reverse) {
-  int id = blockIdx.y * blockDim.x + threadIdx.x;
-
-  for (int outer_index = blockIdx.x; outer_index < outer_dim_size;
-       outer_index += gridDim.x) {
-    for (int inner_index = blockIdx.y * blockDim.x + threadIdx.x;
-         inner_index < inner_dim_size;
-         inner_index += gridDim.y * blockDim.x) {
-      int scan_index_init = 0;
-      int forward_direction = 1;
-      int src_index =
-          outer_index * scan_dim_size * inner_dim_size + inner_index;
-      int dst_index =
-          outer_index * scan_dim_size * inner_dim_size + inner_index;
-      if (reverse) {
-        src_index = src_index + (scan_dim_size - 1) * inner_dim_size;
-        dst_index = dst_index + (scan_dim_size - 1) * inner_dim_size;
-        forward_direction = -1;
-      }
-      if (exclusive) {
-        scan_index_init = 1;
-        out[dst_index] = 0;
-        dst_index = dst_index + (forward_direction * inner_dim_size);
-      }
-      T acc = 0;
-      for (int scan_index = scan_index_init; scan_index < scan_dim_size;
-           ++scan_index) {
-        acc = in[src_index] + acc;
-        out[dst_index] = acc;
-        src_index += (forward_direction * inner_dim_size);
-        dst_index += (forward_direction * inner_dim_size);
-      }
-    }
-  }
-}
-
-// inclusive scan
-template <typename T, int num_threads_x, int num_threads_y>
-__global__ void InnerMostDimInclusiveScan(const T* in, T* out,
-                                          int inner_dim_size,
-                                          int outer_dim_size,
-                                          int scan_dim_size, bool reverse) {
-  __shared__ T share_data[num_threads_y][num_threads_x * 2];
-  T* share_row = share_data[threadIdx.y];
-  int forward_direction = 1;
-  if (reverse) forward_direction = -1;
-
-  for (int block_row = blockIdx.x * blockDim.y; block_row < outer_dim_size;
-       block_row += blockDim.y * gridDim.x) {
-    int row = block_row + threadIdx.y;
-    T acc = 0;
-    const T* row_src = in + row * scan_dim_size;
-    T* row_dst = out + row * scan_dim_size;
-    int block_col = 0;
-    bool loop_condition = (block_col < scan_dim_size);
-    if (reverse) {
-      loop_condition = (block_col >= 0);
-      block_col = scan_dim_size - 1;
-    }
-
-    while (loop_condition) {
-      // Load data into share memory(two value per thread)
-      int col1 = block_col + threadIdx.x * forward_direction;
-      int col2 = block_col + (num_threads_x + threadIdx.x) * forward_direction;
-      if (row < outer_dim_size) {
-        if (col1 < scan_dim_size && col1 >= 0) {
-          share_row[threadIdx.x] = row_src[col1];
-        } else {
-          share_row[threadIdx.x] = 0;
-        }
-
-        if (col2 < scan_dim_size && col2 >= 0) {
-          share_row[num_threads_x + threadIdx.x] = row_src[col2];
-        } else {
-          share_row[num_threads_x + threadIdx.x] = 0;
-        }
-
-        // Add the previous block acc to the result
-        if (threadIdx.x == 0) {
-          share_row[0] = share_row[0] + acc;
-        }
-      }
-      __syncthreads();
-
-      // Up-Sweep
-      for (unsigned s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) {
-        if (row < outer_dim_size && threadIdx.x < s) {
-          unsigned offset = (2 * threadIdx.x + 1) * d - 1;
-          share_row[offset + d] = share_row[offset] + share_row[offset + d];
-        }
-        __syncthreads();
-      }
-
-      // Down-Sweep
-      for (unsigned s = 2, d = blockDim.x / 2; d >= 1; s <<= 1, d >>= 1) {
-        if (row < outer_dim_size && threadIdx.x < s - 1) {
-          unsigned offset = 2 * (threadIdx.x + 1) * d - 1;
-          share_row[offset + d] = share_row[offset] + share_row[offset + d];
-        }
-        __syncthreads();
-      }
-
-      // Write to the output
-      if (row < outer_dim_size) {
-        if (col1 < scan_dim_size && col1 >= 0)
-          row_dst[col1] = share_row[threadIdx.x];
-        if (col2 < scan_dim_size && col2 >= 0)
-          row_dst[col2] = share_row[num_threads_x + threadIdx.x];
-      }
-      acc = share_row[2 * num_threads_x - 1];
-      __syncthreads();
-      block_col += 2 * num_threads_x * forward_direction;
-      if (reverse)
-        loop_condition = (block_col >= 0);
-      else
-        loop_condition = (block_col < scan_dim_size);
-    }
-  }
-}
-
-// exclusive block scan and store block sum for large scan
-template <typename T>
-__global__ void InnerMostDimExclusiveScan(const T* in, T* out, T* sum_data,
-                                          int inner_dim_size,
-                                          int outer_dim_size,
-                                          int scan_dim_size, int two_power,
-                                          bool reverse) {
-  // https://stackoverflow.com/questions/27570552/templated-cuda-kernel-with-dynamic-shared-memory
-  extern __shared__ __align__(sizeof(T)) unsigned char raw_tmp[];
-  T* share_tmp = reinterpret_cast<T*>(raw_tmp);
-  int thread_id = threadIdx.x;
-  int block_id = blockIdx.x;
-  int block_scan_size = blockDim.x * 2;
-  int remain = scan_dim_size % (2 * blockDim.x);
-  if (block_id == gridDim.x - 1 && remain != 0) block_scan_size = remain;
-  int col1 = thread_id;
-  int col2 = thread_id + (block_scan_size) / 2;
-  int index1 = blockIdx.y * (scan_dim_size) + block_id * blockDim.x * 2 + col1;
-  int index2 = blockIdx.y * (scan_dim_size) + block_id * blockDim.x * 2 + col2;
-  if (reverse) {
-    index1 = blockIdx.y * (scan_dim_size) + scan_dim_size - 1 -
-             (block_id * blockDim.x * 2 + col1);
-    index2 = blockIdx.y * (scan_dim_size) + scan_dim_size - 1 -
-             (block_id * blockDim.x * 2 + col2);
-  }
-  int sum_index = blockIdx.y * gridDim.x + block_id;
-  if (thread_id < block_scan_size) {
-    share_tmp[col1 + (col1 >> 5)] = in[index1];
-    share_tmp[col2 + (col2 >> 5)] = in[index2];
-  } else {
-    share_tmp[col1 + (col1 >> 5)] = 0;
-    share_tmp[col2 + (col2 >> 5)] = 0;
-  }
-
-  // Up-Sweep
-  int offset = 1;
-  for (int d = (two_power / 2); d > 0; d >>= 1) {
-    __syncthreads();
-    if (thread_id < d) {
-      int tmp_index1 = offset * (2 * thread_id + 1) - 1;
-      int tmp_index2 = offset * (2 * thread_id + 2) - 1;
-      tmp_index1 = tmp_index1 + (tmp_index1 >> 5);
-      tmp_index2 = tmp_index2 + (tmp_index2 >> 5);
-      share_tmp[tmp_index2] += share_tmp[tmp_index1];
-    }
-    offset *= 2;
-  }
-  __syncthreads();
-
-  if (thread_id == 0) {
-    int tmp_index = (two_power - 1) + ((two_power - 1) >> 5);
-    sum_data[sum_index] = share_tmp[tmp_index];
-    share_tmp[tmp_index] = 0;
-  }
-
-  // Down Sweep
-  for (int d = 1; d < two_power; d *= 2) {
-    offset >>= 1;
-    __syncthreads();
-    if (thread_id < d) {
-      int tmp_index1 = offset * (2 * thread_id + 1) - 1;
-      int tmp_index2 = offset * (2 * thread_id + 2) - 1;
-      tmp_index1 = tmp_index1 + (tmp_index1 >> 5);
-      tmp_index2 = tmp_index2 + (tmp_index2 >> 5);
-      T tmp = share_tmp[tmp_index1];
-      share_tmp[tmp_index1] = share_tmp[tmp_index2];
-      share_tmp[tmp_index2] += tmp;
-    }
-  }
-  __syncthreads();
-
-  if (col1 < block_scan_size) out[index1] = share_tmp[col1 + (col1 >> 5)];
-  if (col2 < block_scan_size) out[index2] = share_tmp[col2 + (col2 >> 5)];
-}
-
-// for large scan_dim_size array we need to add for correct result
-template <typename T>
-__global__ void AddBlockScan(T* result, T* sum, int size, int scan_dim_size,
-                             int sum_size, bool reverse) {
-  int idx = threadIdx.x + blockDim.x * (blockIdx.x + blockIdx.y * gridDim.x);
-  int block_id_start = blockIdx.y * sum_size;
-  int block_id_end = blockIdx.x + blockIdx.y * sum_size;
-  int block_id = blockIdx.x;
-  int thread_id = threadIdx.x;
-
-  int col = block_id * blockDim.x + thread_id + size;
-  int index = blockIdx.y * (scan_dim_size) + col;
-  if (reverse) {
-    index = blockIdx.y * (scan_dim_size) + scan_dim_size - 1 - col;
-  }
-
-  if (col >= scan_dim_size || col < 0) return;
-  for (int i = block_id_start; i <= block_id_end; i++) {
-    result[index] += sum[i];
-  }
-}
+template <typename T, int BLOCK_SIZE>
+__device__ void BlockReverse(const T* idata, T* odata, int src_base,
+                             int dst_base, int valid_item) {
+  __shared__ T sh_mem[BLOCK_SIZE];
+  int tx = threadIdx.x;
+
+  int offset = tx;
+  int in_index = src_base + offset;
+  if (offset >= valid_item) {
+    sh_mem[offset] = 0;
+  } else {
+    int sh_mem_index = BLOCK_SIZE - offset - 1;
+    T data = idata[in_index];
+    sh_mem[sh_mem_index] = data;
+  }
+
+  __syncthreads();
+  int out_index = dst_base - offset;
+  if (offset < valid_item) {
+    int sh_mem_index = BLOCK_SIZE - offset - 1;
+    odata[out_index] = sh_mem[sh_mem_index];
+  }
+}
+
+template <typename T>
+__global__ void MatrixRowReverse(const T* matrix_data, T* reverse_data,
+                                 int reverse_size, int outer_size,
+                                 int inner_size) {
+  int bx = blockIdx.x;
+  int by = blockIdx.y;
+  int item_per_block = 1024;
+
+  for (int block_offset = 0; block_offset < reverse_size;
+       block_offset += item_per_block) {
+    int valid_item = (reverse_size - block_offset > item_per_block)
+                         ? item_per_block
+                         : reverse_size - block_offset;
+    int src_offset =
+        bx * reverse_size + block_offset + by * (inner_size * reverse_size);
+    int dst_offset = bx * reverse_size + by * (inner_size * reverse_size) +
+                     reverse_size - 1 - block_offset;
+    if (reverse_size < item_per_block) {
+      valid_item = reverse_size;
+    }
+
+    BlockReverse<T, 1024>(matrix_data, reverse_data, src_offset, dst_offset,
+                          valid_item);
+  }
+}
+
+template <typename T>
+struct BlockPrefixCallbackOp {
+  // Running prefix
+  T running_total;
+  // Constructor
+  __device__ BlockPrefixCallbackOp(T running_total)
+      : running_total(running_total) {}
+  // Callback operator to be entered by the first warp of threads in the block.
+  // Thread-0 is responsible for returning a value for seeding the block-wide
+  // scan.
+  __device__ T operator()(T block_aggregate) {
+    T old_prefix = running_total;
+    running_total = old_prefix + block_aggregate;
+    return old_prefix;
+  }
+};
+
+// No bank-conflict transpose
+// Same as transposeCoalesced except the first tile dimension is padded
+// to avoid shared memory bank conflicts.
+template <typename T, int TILE_DIM, int BLOCK_ROWS>
+__global__ void MatrixTranspose(T* odata, const T* idata, size_t height,
+                                size_t width) {
+  __shared__ T tile[TILE_DIM][TILE_DIM + 1];
+  int x = blockIdx.x * TILE_DIM + threadIdx.x;
+  int y = blockIdx.y * TILE_DIM + threadIdx.y;
+  for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
+    if (x < width && (y + j) < height) {
+      tile[threadIdx.y + j][threadIdx.x] = idata[(y + j) * width + x];
+    } else {
+      tile[threadIdx.y + j][threadIdx.x] = 0;
+    }
+  }
+
+  __syncthreads();
+
+  x = blockIdx.y * TILE_DIM + threadIdx.x;  // transpose block offset
+  y = blockIdx.x * TILE_DIM + threadIdx.y;
+
+  for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
+    if (x < height && (y + j) < width) {
+      odata[(y + j) * height + x] = tile[threadIdx.x][threadIdx.y + j];
+    }
+  }
+}
+
+template <typename T, int BLOCK_THREADS, int ITEMS_PER_THREAD>
+__global__ void BlockScanKernel(T* d_out, const T* d_in, int inner_size,
+                                int outer_size, int scan_size,
+                                bool exclusive) {
+  // Specialize BlockLoad, BlockStore, and BlockRadixSort collective types
+  typedef cub::BlockLoad<T, BLOCK_THREADS, ITEMS_PER_THREAD,
+                         cub::BLOCK_LOAD_TRANSPOSE>
+      BlockLoadT;
+  typedef cub::BlockStore<T, BLOCK_THREADS, ITEMS_PER_THREAD,
+                          cub::BLOCK_STORE_TRANSPOSE>
+      BlockStoreT;
+  typedef cub::BlockScan<T, BLOCK_THREADS> BlockScanT;
+  // Allocate type-safe, repurposable shared memory for collectives
+  __shared__ union {
+    typename BlockLoadT::TempStorage load;
+    typename BlockStoreT::TempStorage store;
+    typename BlockScanT::TempStorage scan;
+  } temp_storage;
+
+  int bx = blockIdx.x;
+  int by = blockIdx.y;
+
+  BlockPrefixCallbackOp<T> prefix_op(0);
+  T block_aggregate = static_cast<T>(0);
+
+  // Obtain this block's segment of consecutive keys (blocked across threads)
+  int item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;
+  for (int block_offset = 0; block_offset < scan_size;
+       block_offset += BLOCK_THREADS * ITEMS_PER_THREAD) {
+    int valid_item = (scan_size - block_offset > item_per_block)
+                         ? item_per_block
+                         : (scan_size - block_offset);
+    if (scan_size < item_per_block) {
+      valid_item = scan_size;
+    }
+
+    int offset = bx * scan_size + block_offset + by * (inner_size * scan_size);
+
+    T thread_keys[ITEMS_PER_THREAD];
+    BlockLoadT(temp_storage.load)
+        .Load(d_in + offset, thread_keys, valid_item, 0);
+
+    __syncthreads();
+    if (exclusive) {
+      T init_value = static_cast<T>(0);
+      BlockScanT(temp_storage.scan)
+          .ExclusiveScan(thread_keys, thread_keys, cub::Sum(), prefix_op);
+    } else {
+      BlockScanT(temp_storage.scan)
+          .InclusiveScan(thread_keys, thread_keys, cub::Sum(), prefix_op);
+    }
+    __syncthreads();
+
+    BlockStoreT(temp_storage.store)
+        .Store(d_out + offset, thread_keys, valid_item);
+  }
+}
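The core of the new implementation is the CUB pattern in BlockScanKernel above: one thread block walks a long row tile by tile, and a block-prefix callback carries the running total between tiles so that the per-tile scans chain into a single continuous scan. The stand-alone sketch below illustrates that pattern in a simplified inclusive-sum form; RowInclusiveSum, RunningPrefix, and the float-only types are names invented for this illustration and are not part of the commit.

#include <cub/cub.cuh>

// Carries the running total across tiles; called once per tile by the first
// warp, returns the prefix that seeds the tile-wide scan.
struct RunningPrefix {
  float total;
  __device__ explicit RunningPrefix(float init) : total(init) {}
  __device__ float operator()(float tile_aggregate) {
    float old = total;
    total += tile_aggregate;
    return old;
  }
};

// One block computes the inclusive prefix sum of one row of length row_len,
// however long the row is, by scanning BLOCK_THREADS * ITEMS_PER_THREAD
// elements per iteration.
template <int BLOCK_THREADS, int ITEMS_PER_THREAD>
__global__ void RowInclusiveSum(const float* in, float* out, int row_len) {
  typedef cub::BlockLoad<float, BLOCK_THREADS, ITEMS_PER_THREAD,
                         cub::BLOCK_LOAD_TRANSPOSE> LoadT;
  typedef cub::BlockStore<float, BLOCK_THREADS, ITEMS_PER_THREAD,
                          cub::BLOCK_STORE_TRANSPOSE> StoreT;
  typedef cub::BlockScan<float, BLOCK_THREADS> ScanT;
  __shared__ union {
    typename LoadT::TempStorage load;
    typename StoreT::TempStorage store;
    typename ScanT::TempStorage scan;
  } temp;

  const float* row_in = in + blockIdx.x * row_len;
  float* row_out = out + blockIdx.x * row_len;
  RunningPrefix prefix(0.0f);

  const int tile = BLOCK_THREADS * ITEMS_PER_THREAD;
  for (int offset = 0; offset < row_len; offset += tile) {
    int valid = min(tile, row_len - offset);
    float items[ITEMS_PER_THREAD];
    LoadT(temp.load).Load(row_in + offset, items, valid, 0.0f);
    __syncthreads();
    ScanT(temp.scan).InclusiveScan(items, items, cub::Sum(), prefix);
    __syncthreads();
    StoreT(temp.store).Store(row_out + offset, items, valid);
    __syncthreads();
  }
}

A launch such as RowInclusiveSum<128, 4><<<num_rows, 128>>>(d_in, d_out, row_len) then scans every row with one block each, which mirrors how BlockScanKernel is used in the Compute method of the second hunk.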
@@ -298,72 +234,79 @@ class CumCUDAKernel : public framework::OpKernel<T> {
       return;
     }
-    const int& scan_dim_size = out_dims[axis];
-    bool optimize_condition = (axis == (out_dims.size() - 1)) ? true : false;
-    int outer_dim_size = 1;
-    int inner_dim_size = 1;
-    // treat all dim index < axis as outer_dim_size
-    for (size_t i = 0; i < axis; i++) {
-      outer_dim_size *= out_dims[i];
-    }
-    // treat all dim index > axis as innner_dim_size
-    for (size_t i = axis + 1; i < out_dims.size(); i++) {
-      inner_dim_size *= out_dims[i];
-    }
+    size_t height = 1;
+    size_t width = 1;
+    for (size_t i = 0; i <= axis; i++) {
+      height *= out_dims[i];
+    }
+    for (size_t i = axis + 1; i < out_dims.size(); i++) {
+      width *= out_dims[i];
+    }
+    int scan_size = out_dims[axis];
+    bool transpose = (axis != out_dims.size() - 1);
+
+    int tile_size = 32;
+    dim3 blocks(32, 8);
+    dim3 transpose_grids((width + tile_size - 1) / tile_size,
+                         (height + tile_size - 1) / tile_size);
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    if (optimize_condition) {
-      auto nextPowerOfTwo = [](int x) -> int {
-        int ret = 1;
-        while (ret < x) ret = ret * 2;
-        return ret;
-      };
-      if (exclusive) {
-        int element_per_block = nextPowerOfTwo(scan_dim_size) / 2;
-        if (element_per_block > 512 || element_per_block < 32) {
-          element_per_block = 64;
-        }
-        int two_power = element_per_block * 2;
-        dim3 block(element_per_block);
-        dim3 grid(((scan_dim_size + 1) / 2 + block.x - 1) / block.x,
-                  outer_dim_size);
-        int offset_size = (element_per_block * 2) >> 5;
-        int share_mem_size = (element_per_block * 2 + offset_size) * sizeof(T);
-        Tensor scan_sum;
-        paddle::framework::DDim dims{
-            ((scan_dim_size + 1) / 2 + block.x - 1) / block.x, outer_dim_size};
-        scan_sum.Resize(dims);
-        T* sum_data = scan_sum.mutable_data<T>(context.GetPlace());
-        InnerMostDimExclusiveScan<T><<<grid, block, share_mem_size,
-                                       dev_ctx.stream()>>>(
-            in_data, out_data, sum_data, inner_dim_size, outer_dim_size,
-            scan_dim_size, two_power, reverse);
-        // for large scan array we need to do add for correct result
-        int element_size = element_per_block * 2;
-        if (scan_dim_size > element_size) {
-          dim3 sum_block(element_per_block * 2);
-          dim3 sum_grid((scan_dim_size - element_size + block.x - 1) / block.x,
-                        outer_dim_size);
-          int sum_size = ((scan_dim_size + 1) / 2 + block.x - 1) / block.x;
-          AddBlockScan<T><<<sum_grid, sum_block, 0, dev_ctx.stream()>>>(
-              out_data, sum_data, element_size, scan_dim_size, sum_size,
-              reverse);
-        }
-      } else {
-        dim3 block(32, 16);
-        dim3 grid((outer_dim_size + block.y - 1) / block.y);
-        InnerMostDimInclusiveScan<T, 32, 16><<<grid, block, 0,
-                                               dev_ctx.stream()>>>(
-            in_data, out_data, inner_dim_size, outer_dim_size, scan_dim_size,
-            reverse);
-      }
-    } else {
-      dim3 block(std::min(512, inner_dim_size));
-      dim3 grid(outer_dim_size, (inner_dim_size + block.x - 1) / block.x);
-      OuterScan<T><<<grid, block, 0, dev_ctx.stream()>>>(
-          in_data, out_data, inner_dim_size, outer_dim_size, scan_dim_size,
-          exclusive, reverse);
-    }
+    Tensor tmp;
+    tmp.Resize(out_dims);
+    auto* tmp_data = tmp.mutable_data<T>(context.GetPlace());
+    T* next_in_data = out_data;
+    T* next_out_data = tmp_data;
+    if (transpose) {
+      MatrixTranspose<T, 32, 8><<<transpose_grids, blocks, 0,
+                                  dev_ctx.stream()>>>(out_data, in_data,
+                                                      height, width);
+      next_in_data = out_data;
+      next_out_data = tmp_data;
+    }
+    auto swap_ptr = [](T*& ptr1, T*& ptr2) {
+      T* tmp = ptr2;
+      ptr2 = ptr1;
+      ptr1 = tmp;
+    };
+    int outer_size = height / scan_size;
+    int inner_size = width;
+    // Consider the size of shared memory, here block size is 128
+    dim3 scan_grid(outer_size, inner_size);
+    dim3 reverse_grid = scan_grid;
+    if (reverse) {
+      if (transpose) {
+        reverse_grid.x = scan_grid.y;
+        reverse_grid.y = scan_grid.x;
+        MatrixRowReverse<T><<<reverse_grid, 1024, 0, dev_ctx.stream()>>>(
+            next_in_data, next_out_data, scan_size, outer_size, inner_size);
+        if (!transpose) next_in_data = tmp_data;
+        swap_ptr(next_in_data, next_out_data);
+      } else {
+        MatrixRowReverse<T><<<reverse_grid, 1024, 0, dev_ctx.stream()>>>(
+            in_data, out_data, scan_size, outer_size, inner_size);
+      }
+    }
+    if (!transpose && !reverse) {
+      BlockScanKernel<T, 128, 4><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
+          out_data, in_data, outer_size, inner_size, scan_size, exclusive);
+    } else {
+      BlockScanKernel<T, 128, 4><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
+          next_out_data, next_in_data, outer_size, inner_size, scan_size,
+          exclusive);
+    }
+    swap_ptr(next_in_data, next_out_data);
+    if (reverse) {
+      MatrixRowReverse<T><<<reverse_grid, 1024, 0, dev_ctx.stream()>>>(
+          next_in_data, next_out_data, scan_size, outer_size, inner_size);
+      swap_ptr(next_in_data, next_out_data);
+    }
+    if (transpose) {
+      transpose_grids.x = (height + tile_size - 1) / tile_size;
+      transpose_grids.y = (width + tile_size - 1) / tile_size;
+      MatrixTranspose<T, 32, 8><<<transpose_grids, blocks, 0,
+                                  dev_ctx.stream()>>>(next_out_data,
+                                                      next_in_data, width,
+                                                      height);
+    }
   }
 };
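The rewritten Compute path reduces every case to contiguous row scans: when the scan axis is not the innermost one the tensor is transposed first, an optional row reverse handles reverse=true, BlockScanKernel performs the scan, and the reverse and transpose are then undone. The small host-side sketch below only mirrors the shape bookkeeping visible above; PlanCumsumLaunch is an invented name taking a plain std::vector shape, so this is an illustration of the index arithmetic rather than Paddle code.

#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors the shape bookkeeping in CumCUDAKernel::Compute above: everything up
// to and including `axis` is folded into `height`, everything after it into
// `width`; a transpose is needed whenever the scan axis is not innermost.
void PlanCumsumLaunch(const std::vector<int64_t>& dims, int axis) {
  int64_t height = 1, width = 1;
  for (size_t i = 0; i <= static_cast<size_t>(axis); ++i) height *= dims[i];
  for (size_t i = axis + 1; i < dims.size(); ++i) width *= dims[i];

  int64_t scan_size = dims[axis];
  bool transpose = (axis != static_cast<int>(dims.size()) - 1);

  // After the optional transpose, the data is viewed as rows of length
  // scan_size; the scan grid is (outer_size, inner_size) blocks and each block
  // scans one contiguous row with BlockScanKernel.
  int64_t outer_size = height / scan_size;
  int64_t inner_size = width;

  std::printf("height=%lld width=%lld transpose=%d outer=%lld inner=%lld\n",
              static_cast<long long>(height), static_cast<long long>(width),
              static_cast<int>(transpose), static_cast<long long>(outer_size),
              static_cast<long long>(inner_size));
}

// Example: a [2, 3, 4] tensor scanned along axis 1 gives height=6, width=4,
// transpose=true, outer_size=2, inner_size=4.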