Commit 0de94cd9 (unverified)
Authored on Jan 03, 2023 by limingshu; committed via GitHub on Jan 03, 2023.
H2D data transfer optimization for concat kernel (#49040)
Parent: f484a61e
Showing 3 changed files with 342 additions and 279 deletions (+342 -279):

    paddle/phi/backends/gpu/gpu_launch_config.h              +8    -3
    paddle/phi/kernels/funcs/concat_and_split_functor.cu    +264  -226
    paddle/phi/kernels/gpu/stack_kernel.cu                   +70   -50
paddle/phi/backends/gpu/gpu_launch_config.h

@@ -53,20 +53,25 @@ inline T DivUp(T a, T b) {
 // https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
 // for round integer value into next highest power of 2.
-inline int64_t RoundToPowerOfTwo(int64_t n) {
+inline int64_t RoundToNextHighPowOfTwo(int64_t n, int64_t min_val = 1) {
   n--;
   n |= (n >> 1);
   n |= (n >> 2);
   n |= (n >> 4);
   n |= (n >> 8);
   n |= (n >> 16);
-  int64_t min_val = 32;
+  return std::max(min_val, (n + 1));
+}
+
+inline int64_t RoundToPowerOfTwo(int64_t n) {
+  constexpr int64_t min_val = 32;
+  int64_t num = RoundToNextHighPowOfTwo(n, min_val);
 #ifdef __HIPCC__
   int64_t max_val = 256;
 #else
   int64_t max_val = 1024;
 #endif
-  return std::min(max_val, std::max(min_val, (n + 1)));
+  return std::min(max_val, num);
 }

 #ifdef WITH_NV_JETSON
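The bit-twiddling above rounds n up to the next power of two and clamps it from below by min_val; the concat and stack kernels below use it to map an arbitrary input-tensor count onto a small set of statically sized wrapper structs. A minimal standalone sketch of the intended behaviour (the helper body is copied here purely for illustration):

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    // Copy of the helper shown in the diff above, for illustration only.
    inline int64_t RoundToNextHighPowOfTwo(int64_t n, int64_t min_val = 1) {
      n--;
      n |= (n >> 1);
      n |= (n >> 2);
      n |= (n >> 4);
      n |= (n >> 8);
      n |= (n >> 16);
      return std::max(min_val, (n + 1));
    }

    int main() {
      // 6 input tensors round up to the "8" case; 3 tensors are clamped to
      // the minimum case of 4 used by the concat dispatch below.
      std::cout << RoundToNextHighPowOfTwo(6, 4) << "\n";    // prints 8
      std::cout << RoundToNextHighPowOfTwo(3, 4) << "\n";    // prints 4
      std::cout << RoundToNextHighPowOfTwo(200, 4) << "\n";  // prints 256
      return 0;
    }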
paddle/phi/kernels/funcs/concat_and_split_functor.cu

The previous int64_t-only ConcatKernel_ family (one generic kernel reading its input pointers and column offsets from device memory, plus hand-written overloads for 2, 3, and 4 raw input pointers) is replaced by two index-type-templated kernels whose pointer tables are carried either inline in the kernel argument, for a bounded number of inputs, or through a single device allocation in the general case.

@@ -15,49 +15,155 @@ limitations under the License. */

 #include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"

 namespace phi {
 namespace funcs {

The old ConcatKernel_ / ConcatKernelDetail pair is removed here; the added code is:

template <typename T, int Size>
struct PointerWrapper {
 public:
  const T* ins_addr[Size];
  __device__ inline const T* operator[](int i) const { return ins_addr[i]; }

  PointerWrapper() {}
  PointerWrapper(const phi::GPUContext& ctx,
                 const std::vector<phi::DenseTensor>& ins,
                 const T** pre_alloced_host_ptr) {
    for (auto i = 0; i < ins.size(); ++i) {
      ins_addr[i] = ins[i].data<T>();
    }
  }
};

template <typename T>
struct PointerToPointer {
 public:
  T** ins_addr{nullptr};
  __device__ inline const T* operator[](int i) const { return ins_addr[i]; }

  PointerToPointer() {}
  PointerToPointer(const phi::GPUContext& ctx,
                   const std::vector<phi::DenseTensor>& ins,
                   const T** pre_alloced_host_ptr,
                   paddle::memory::AllocationPtr* dev_ins_ptr) {
    auto in_num = ins.size();
    for (auto i = 0; i < in_num; ++i) {
      pre_alloced_host_ptr[i] = ins[i].data<T>();
    }
    *dev_ins_ptr = paddle::memory::Alloc(
        ctx.GetPlace(),
        in_num * sizeof(T*),
        phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
    auto* restored = phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(
        pre_alloced_host_ptr, in_num);
    paddle::memory::Copy(ctx.GetPlace(),
                         (*dev_ins_ptr)->ptr(),
                         phi::CPUPlace(),
                         restored,
                         in_num * sizeof(T*),
                         ctx.stream());
    ins_addr = reinterpret_cast<T**>((*dev_ins_ptr)->ptr());
  }
};

template <typename T, typename IndexT, int Size>
struct PointerAndColWrapper {
 public:
  IndexT col_length[Size];
  PointerAndColWrapper(const phi::GPUContext& ctx,
                       const std::vector<phi::DenseTensor>& ins,
                       const IndexT& inputs_col_num,
                       const T** pre_alloced_host_ptr,
                       IndexT* inputs_col) {
    for (auto i = 0; i < inputs_col_num; ++i) {
      col_length[i] = inputs_col[i];
    }
    ins_ptr_wrapper = PointerWrapper<T, Size>(ctx, ins, pre_alloced_host_ptr);
  }

  __device__ inline const T* operator[](int i) const {
    return ins_ptr_wrapper[i];
  }

 private:
  PointerWrapper<T, Size> ins_ptr_wrapper;
};

template <typename T, typename IndexT>
struct PointerToPointerAndCol {
 public:
  IndexT* col_length{nullptr};
  PointerToPointerAndCol(const phi::GPUContext& ctx,
                         const std::vector<phi::DenseTensor>& ins,
                         const IndexT inputs_col_num,
                         const T** pre_alloced_host_ptr,
                         IndexT* inputs_col,
                         paddle::memory::AllocationPtr* dev_ins_ptr,
                         paddle::memory::AllocationPtr* dev_col_ptr) {
    *dev_col_ptr = paddle::memory::Alloc(
        ctx.GetPlace(),
        inputs_col_num * sizeof(IndexT),
        phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
    auto* restored = phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(
        inputs_col, inputs_col_num);
    paddle::memory::Copy(ctx.GetPlace(),
                         (*dev_col_ptr)->ptr(),
                         phi::CPUPlace(),
                         restored,
                         inputs_col_num * sizeof(IndexT),
                         ctx.stream());
    col_length = static_cast<IndexT*>((*dev_col_ptr)->ptr());
    ins_ptr_wrapper =
        PointerToPointer<T>(ctx, ins, pre_alloced_host_ptr, dev_ins_ptr);
  }

  __device__ inline const T* operator[](int i) const {
    return ins_ptr_wrapper[i];
  }

 private:
  PointerToPointer<T> ins_ptr_wrapper;
};

template <typename T, typename IndexT, typename PointerAndColWrapperT>
__global__ void ConcatTensorWithDifferentShape(PointerAndColWrapperT ins_datas,
                                               int col_size,
                                               const IndexT output_rows,
                                               const IndexT output_cols,
                                               T* output) {
  IndexT curr_segment = 0;
  IndexT curr_offset = ins_datas.col_length[0];
  CUDA_KERNEL_LOOP_TYPE(tid_x, output_cols, IndexT) {
    IndexT curr_col_offset = ins_datas.col_length[curr_segment + 1];
    while (curr_col_offset <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
      curr_col_offset = ins_datas.col_length[curr_segment + 1];
    }

    IndexT local_col = tid_x - curr_offset;
    IndexT segment_width = curr_col_offset - curr_offset;

    const T* input_ptr = ins_datas[curr_segment];
    IndexT tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
      output[tid_y * output_cols + tid_x] =
          input_ptr[tid_y * segment_width + local_col];
  }
}

template <typename T, typename IndexT, typename PointerWrapperT>
__global__ void ConcatTensorWithSameShape(PointerWrapperT ins_data,
                                          const IndexT fixed_in_col,
                                          const IndexT out_rows,
                                          const IndexT out_cols,
                                          T* output_data) {
  CUDA_KERNEL_LOOP_TYPE(tid_x, out_cols, IndexT) {
    IndexT split = tid_x / fixed_in_col;
    IndexT in_offset = tid_x - split * fixed_in_col;
    const T* input_ptr = ins_data[split];
    IndexT tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
      output_data[tid_y * out_cols + tid_x] =
          input_ptr[tid_y * fixed_in_col + in_offset];

@@ -65,65 +171,6 @@ __device__ void ConcatKernelDetail(const T** inputs_data,

This hunk deletes the remaining ConcatKernel_ overloads: the variants taking 2, 3, and 4 raw input pointers plus a fixed column width, and the generic variant taking (const T** inputs_data, const int in_num, ...), all of which simply forwarded to ConcatKernelDetail<T>. The SplitKernel_ declarations that follow are unchanged context.

@@ -254,155 +301,146 @@ static inline void GetBlockDims(const phi::GPUContext& context,

/*
 * All tensors' dimension should be the same and the values of
 * each dimension must be the same, except the axis dimension.
 */
template <typename T, typename IndexT>
void ConcatFunctorWithIndexType(const phi::GPUContext& ctx,
                                const std::vector<phi::DenseTensor>& ins,
                                int axis,
                                phi::DenseTensor* output) {
  // TODO(zcd): Add input data validity checking
  IndexT in_num = ins.size();
  IndexT in_row = 1;
  auto dim_0 = ins[0].dims();
  for (int i = 0; i < axis; ++i) {
    in_row *= dim_0[i];
  }
  IndexT in_col = ins[0].numel() / in_row;
  IndexT out_row = in_row, out_col = 0;

  IndexT inputs_col_num = in_num + 1;
  std::vector<const T*> inputs_data_vec(in_num, nullptr);
  std::vector<IndexT> inputs_col_vec(inputs_col_num, 0);
  const T** inputs_data = inputs_data_vec.data();
  IndexT* inputs_col = inputs_col_vec.data();

  // There are some differences between hip runtime and NV runtime.
  // In NV, when the pageable memory data less than 64K is transferred from
  // host to device, it will be automatically asynchronous.
  // However, only pinned memory in hip can copy asynchronously
  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device
  // 3.2.6.1. Concurrent Execution between Host and Device
  // Memory copies from host to device of a memory block of 64 KB or less
#ifdef PADDLE_WITH_HIP
  paddle::memory::AllocationPtr data_alloc, col_alloc;
  // TODO(chentianyu03): try to find a method to remove the Alloc function
  data_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(),
                                     in_num * sizeof(T*));
  inputs_data = reinterpret_cast<const T**>(data_alloc->ptr());
  // TODO(chentianyu03): try to find a method to remove the Alloc function
  col_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(),
                                    inputs_col_num * sizeof(IndexT));
  inputs_col = reinterpret_cast<IndexT*>(col_alloc->ptr());
#endif

  inputs_col[0] = 0;
  bool has_same_shape = true;
  for (int i = 0; i < in_num; ++i) {
    IndexT t_cols = ins[i].numel() / in_row;
    if (has_same_shape) {
      has_same_shape &= (t_cols == in_col);
    }
    out_col += t_cols;
    inputs_col[i + 1] = out_col;
  }

  dim3 block_dims;
  dim3 grid_dims;
  GetBlockDims(ctx, out_row, out_col, &block_dims, &grid_dims);
  IndexT limit_num = has_same_shape ? in_num : inputs_col_num;

#define IMPL_CONCATE_CUDA_KERNEL_HELPER(func_impl, ...) \
  func_impl(4, ##__VA_ARGS__);                          \
  func_impl(8, ##__VA_ARGS__);                          \
  func_impl(16, ##__VA_ARGS__);                         \
  func_impl(32, ##__VA_ARGS__);                         \
  func_impl(64, ##__VA_ARGS__);                         \
  func_impl(128, ##__VA_ARGS__);

  if (has_same_shape) {
#define IMPL_CONCAT_CUDA_KERNEL_CASE(size_, ...)               \
  case size_: {                                                \
    PointerWrapper<T, size_> ptr_array(ctx, ins, inputs_data); \
    __VA_ARGS__;                                               \
  } break;
    switch (phi::backends::gpu::RoundToNextHighPowOfTwo(limit_num, 4)) {
      IMPL_CONCATE_CUDA_KERNEL_HELPER(
          IMPL_CONCAT_CUDA_KERNEL_CASE,
          ConcatTensorWithSameShape<T, IndexT, decltype(ptr_array)>
          <<<grid_dims, block_dims, 0, ctx.stream()>>>(
              ptr_array, in_col, out_row, out_col, output->data<T>()));
      default: {
        paddle::memory::AllocationPtr dev_ins_ptr{nullptr};
        PointerToPointer<T> ptr_array(ctx, ins, inputs_data, &dev_ins_ptr);
        ConcatTensorWithSameShape<T, IndexT, decltype(ptr_array)>
            <<<grid_dims, block_dims, 0, ctx.stream()>>>(
                ptr_array, in_col, out_row, out_col, output->data<T>());
      }
    }
#undef IMPL_CONCAT_CUDA_KERNEL_CASE
  } else {
#define IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE(size_, ...)    \
  case size_: {                                             \
    PointerAndColWrapper<T, IndexT, size_> ptr_col_array(   \
        ctx, ins, inputs_col_num, inputs_data, inputs_col); \
    __VA_ARGS__;                                            \
  } break;
    switch (phi::backends::gpu::RoundToNextHighPowOfTwo(limit_num, 4)) {
      IMPL_CONCATE_CUDA_KERNEL_HELPER(
          IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE,
          ConcatTensorWithDifferentShape<T, IndexT, decltype(ptr_col_array)>
          <<<grid_dims, block_dims, 0, ctx.stream()>>>(
              ptr_col_array, inputs_col_num, out_row, out_col,
              output->data<T>()));
      default: {
        paddle::memory::AllocationPtr dev_ins_ptr{nullptr};
        paddle::memory::AllocationPtr dev_col_ptr{nullptr};
        PointerToPointerAndCol<T, IndexT> ptr_col_array(ctx,
                                                        ins,
                                                        inputs_col_num,
                                                        inputs_data,
                                                        inputs_col,
                                                        &dev_ins_ptr,
                                                        &dev_col_ptr);
        ConcatTensorWithDifferentShape<T, IndexT, decltype(ptr_col_array)>
            <<<grid_dims, block_dims, 0, ctx.stream()>>>(
                ptr_col_array, inputs_col_num, out_row, out_col,
                output->data<T>());
      }
    }
#undef IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE
  }
#undef IMPL_CONCATE_CUDA_KERNEL_HELPER

#ifdef PADDLE_WITH_HIP
  // Prevent pinned memory from being covered and release the memory after
  // kernel launch of the stream is executed (reapply pinned memory next time)
  auto* data_alloc_released = data_alloc.release();
  auto* col_alloc_released = col_alloc.release();
  ctx.AddStreamCallback([data_alloc_released, col_alloc_released] {
    VLOG(4) << "Delete cuda pinned at " << data_alloc_released;
    VLOG(4) << "Delete cuda pinned at " << col_alloc_released;
    paddle::memory::allocation::Allocator::AllocationDeleter(
        data_alloc_released);
    paddle::memory::allocation::Allocator::AllocationDeleter(
        col_alloc_released);
  });
#endif
}

template <typename T>
struct ConcatFunctor<phi::GPUContext, T> {
  void operator()(const phi::GPUContext& context,
                  const std::vector<phi::DenseTensor>& input,
                  int axis,
                  phi::DenseTensor* output) {
    if (output->numel() < std::numeric_limits<int32_t>::max()) {
      ConcatFunctorWithIndexType<T, int32_t>(context, input, axis, output);
    } else {
      ConcatFunctorWithIndexType<T, int64_t>(context, input, axis, output);
    }
  }
};

This replaces the old ConcatFunctor<phi::GPUContext, T>::operator(), which built int64_t-indexed pointer and column tables, copied them to the device with paddle::memory::Alloc + paddle::memory::Copy on every call (except when all inputs had the same shape and there were 2 to 4 of them), and then launched one of the ConcatKernel_ overloads.

@@ -488,7 +526,7 @@ class SplitFunctor<phi::GPUContext, T> {

In the SplitFunctor output-pointer copy, the source place argument of paddle::memory::Copy changes from paddle::platform::CPUPlace() to phi::CPUPlace():

  paddle::memory::Copy(context.GetPlace(),
                       tmp_dev_outs_data->ptr(),
                       phi::CPUPlace(),
                       restored,
                       o_num * sizeof(T*),
                       context.stream());

@@ -539,7 +577,7 @@ class SplitFunctor<phi::GPUContext, T> {

Likewise for the output-columns copy:

  paddle::memory::Copy(context.GetPlace(),
                       tmp_dev_ins_col_data->ptr(),
                       phi::CPUPlace(),
                       restored,
                       outputs_cols_num * sizeof(int64_t),
                       context.stream());
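The core of the H2D optimization above is that, for up to 128 inputs, the table of input pointers travels inside the kernel argument itself instead of through a separate paddle::memory::Copy to device memory before every launch. A minimal standalone CUDA sketch of that pattern, with names and sizes that are illustrative rather than Paddle's:

    #include <cstdio>
    #include <cuda_runtime.h>

    // A fixed-capacity pointer table passed to the kernel by value, so the
    // addresses are delivered through the kernel-argument buffer and no
    // separate host-to-device memcpy of the table is needed.
    template <typename T, int Size>
    struct ArgPointerTable {
      const T* ptrs[Size];
      __device__ const T* operator[](int i) const { return ptrs[i]; }
    };

    // Concatenate `num` equally sized segments of length `seg_len` into `out`.
    template <typename T, typename TableT>
    __global__ void ConcatSameShape(TableT table, int num, int seg_len, T* out) {
      int tid = blockIdx.x * blockDim.x + threadIdx.x;
      if (tid < num * seg_len) {
        int seg = tid / seg_len;
        out[tid] = table[seg][tid - seg * seg_len];
      }
    }

    int main() {
      constexpr int kNum = 3, kLen = 4;
      float host[kNum][kLen], *dev[kNum], *out;
      ArgPointerTable<float, 8> table{};  // capacity rounded up, like the 4/8/.../128 cases
      for (int i = 0; i < kNum; ++i) {
        for (int j = 0; j < kLen; ++j) host[i][j] = i * 10.f + j;
        cudaMalloc(&dev[i], kLen * sizeof(float));
        cudaMemcpy(dev[i], host[i], kLen * sizeof(float), cudaMemcpyHostToDevice);
        table.ptrs[i] = dev[i];
      }
      cudaMalloc(&out, kNum * kLen * sizeof(float));
      ConcatSameShape<float, ArgPointerTable<float, 8>>
          <<<1, 128>>>(table, kNum, kLen, out);  // table goes by value, no extra H2D copy
      float result[kNum * kLen];
      cudaMemcpy(result, out, sizeof(result), cudaMemcpyDeviceToHost);
      printf("%.0f %.0f\n", result[0], result[11]);  // expect 0 and 23
      return 0;
    }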
paddle/phi/kernels/gpu/stack_kernel.cu

The stack kernel gets the same treatment as concat: SetDivden/dividen is renamed to SetDivisor/divisor, PointerArray gains a compile-time Size parameter in place of the fixed kWarpperSize = 64, PointerToPointer hands its device allocation back through a caller-owned AllocationPtr, and launches are dispatched over RoundToNextHighPowOfTwo(num, 4) for the wrapper size and over out->numel() for the index type.

@@ -25,7 +25,9 @@ namespace phi {

template <typename IndexT>
struct DivmodWarpper {
 public:
  void SetDivisor(IndexT divisor) {
    divmoder = phi::funcs::FastDivMod(divisor);
  }
  __device__ inline phi::funcs::FastDivMod::DivModT div_mod(IndexT val) {
    return divmoder.Divmod(val);
  }

@@ -39,7 +41,7 @@ struct DivmodWarpper<int64_t> {

 public:
  using DivModT = phi::AlignedVector<int64_t, 2>;

  void SetDivisor(int64_t divisor) { dividen_ = divisor; }
  __device__ inline DivModT div_mod(int64_t val) {
    DivModT data;
    data[0] = val / dividen_;

@@ -51,15 +53,14 @@ struct DivmodWarpper<int64_t> {

  int64_t dividen_;
};

template <typename T, typename IndexT, int Size>
struct PointerArray : public DivmodWarpper<IndexT> {
 public:
  const T* data[Size];
  PointerArray(const std::vector<const DenseTensor*>& x,
               int num,
               IndexT divisor) {
    this->SetDivisor(divisor);
    for (auto i = 0; i < num; ++i) {
      data[i] = x[i]->data<T>();
    }

@@ -69,33 +70,33 @@ struct PointerArray : public DivmodWarpper<IndexT> {

template <typename Context, typename T, typename IndexT>
struct PointerToPointer : public DivmodWarpper<IndexT> {
 public:
  T** data{nullptr};
  PointerToPointer(const Context& ctx,
                   const std::vector<const DenseTensor*>& x,
                   IndexT num,
                   IndexT divisor,
                   paddle::memory::AllocationPtr* dev_ins_ptr) {
    this->SetDivisor(divisor);
    std::vector<const T*> x_datas(num);
    for (int i = 0; i < num; ++i) {
      x_datas[i] = x[i]->data<T>();
    }
    *dev_ins_ptr = paddle::memory::Alloc(
        ctx.GetPlace(),
        num * sizeof(T*),
        phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
    paddle::memory::Copy(ctx.GetPlace(),
                         (*dev_ins_ptr)->ptr(),
                         phi::CPUPlace(),
                         reinterpret_cast<void*>(x_datas.data()),
                         num * sizeof(T*),
                         ctx.stream());
    data = reinterpret_cast<T**>((*dev_ins_ptr)->ptr());
  }
};

template <typename T, typename IndexT, typename WrapT>
__global__ void StackCUDAKernel(WrapT input_warpper,
                                IndexT split_size,
                                IndexT rows,
                                IndexT cols,

@@ -117,14 +118,56 @@ __global__ void StackCUDAKernel(WarpT input_warpper,

  }
}

template <typename T, typename IndexT, typename Context>
void LaunchStackCUDAKernelWithIndexType(
    const Context& ctx,
    const IndexT x_col,
    const IndexT x_row,
    const IndexT out_col,
    const phi::backends::gpu::GpuLaunchConfig& cfg,
    const std::vector<const DenseTensor*>& x,
    T* dst_data) {
  int num = static_cast<int>(x.size());
#define IMPL_STACK_CUDA_KERNEL_CASE(size_, ...)              \
  case size_: {                                              \
    PointerArray<T, IndexT, size_> ptr_array(x, num, x_col); \
    __VA_ARGS__;                                             \
  } break;

#define IMPL_STACK_CUDA_KERNEL_HELPER(...)          \
  IMPL_STACK_CUDA_KERNEL_CASE(2, ##__VA_ARGS__);    \
  IMPL_STACK_CUDA_KERNEL_CASE(8, ##__VA_ARGS__);    \
  IMPL_STACK_CUDA_KERNEL_CASE(16, ##__VA_ARGS__);   \
  IMPL_STACK_CUDA_KERNEL_CASE(32, ##__VA_ARGS__);   \
  IMPL_STACK_CUDA_KERNEL_CASE(64, ##__VA_ARGS__);   \
  IMPL_STACK_CUDA_KERNEL_CASE(128, ##__VA_ARGS__);

  switch (phi::backends::gpu::RoundToNextHighPowOfTwo(num, 4)) {
    IMPL_STACK_CUDA_KERNEL_HELPER(
        StackCUDAKernel<T, IndexT, decltype(ptr_array)>
        <<<cfg.block_per_grid, cfg.thread_per_block, 0, ctx.stream()>>>(
            ptr_array, x_col, x_row, out_col, dst_data));
    default: {
      paddle::memory::AllocationPtr dev_ins_ptr{nullptr};
      PointerToPointer<Context, T, IndexT> ptr_array(
          ctx, x, num, x_col, &dev_ins_ptr);
      StackCUDAKernel<T, IndexT, decltype(ptr_array)>
          <<<cfg.block_per_grid, cfg.thread_per_block, 0, ctx.stream()>>>(
              ptr_array, x_col, x_row, out_col, dst_data);
    }
  }
#undef IMPL_STACK_CUDA_KERNEL_HELPER
#undef IMPL_STACK_CUDA_KERNEL_CASE
}

template <typename T, typename Context>
void StackKernel(const Context& dev_ctx,
                 const std::vector<const DenseTensor*>& x,
                 int axis,
                 DenseTensor* out) {
  if (axis < 0) axis += (x[0]->dims().size() + 1);
  int num = static_cast<int>(x.size());
  T* dst_data = dev_ctx.template Alloc<T>(out);
  // Split x dim from axis to matrix
  int64_t x_row = 1, x_col = 1;

@@ -132,40 +175,17 @@ void StackKernel(const Context& dev_ctx,

    x_row *= x[0]->dims()[i];
  }
  x_col = x[0]->numel() / x_row;
  int64_t out_col = x_col * num;
  auto config =
      phi::backends::gpu::GetGpuLaunchConfig2D(dev_ctx, out_col, x_row);

  if (out->numel() < std::numeric_limits<int32_t>::max()) {
    LaunchStackCUDAKernelWithIndexType<T, int32_t, Context>(
        dev_ctx, x_col, x_row, out_col, config, x, dst_data);
  } else {
    LaunchStackCUDAKernelWithIndexType<T, int64_t, Context>(
        dev_ctx, x_col, x_row, out_col, config, x, dst_data);
  }
}

}  // namespace phi

The old IMPL_STACK_CUDA_KERNEL launch macro and the n <= kWarpperSize / use_int32 branching it served are removed, since LaunchStackCUDAKernelWithIndexType now covers both the statically sized and the pointer-table paths.
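Both files rely on the same dispatch idiom: a macro expands into switch cases whose bodies instantiate a wrapper with a compile-time capacity, and the runtime input count is first rounded up with RoundToNextHighPowOfTwo so it lands on one of those cases. A minimal CPU-only sketch of the idiom, with names that are illustrative rather than Paddle's:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Same rounding helper as in gpu_launch_config.h, copied for illustration.
    inline int64_t RoundToNextHighPowOfTwo(int64_t n, int64_t min_val = 1) {
      n--;
      n |= (n >> 1); n |= (n >> 2); n |= (n >> 4); n |= (n >> 8); n |= (n >> 16);
      return std::max(min_val, (n + 1));
    }

    // Fixed-capacity argument pack, in the spirit of PointerWrapper<T, Size>.
    template <int Size>
    struct FixedPack {
      const float* ptrs[Size];
    };

    template <typename PackT>
    void RunKernel(const PackT& pack, int num) {
      // Stand-in for a kernel launch; just touches each pointer.
      for (int i = 0; i < num; ++i) std::printf("%p\n", (const void*)pack.ptrs[i]);
    }

    // Each case builds a pack sized to the case label, then runs the launch
    // expression passed as __VA_ARGS__.
    #define IMPL_PACK_CASE(size_, ...)                            \
      case size_: {                                               \
        FixedPack<size_> pack{};                                  \
        for (int i = 0; i < num; ++i) pack.ptrs[i] = inputs[i];   \
        __VA_ARGS__;                                              \
      } break;

    void Dispatch(const std::vector<const float*>& inputs) {
      int num = static_cast<int>(inputs.size());
      switch (RoundToNextHighPowOfTwo(num, 4)) {
        IMPL_PACK_CASE(4, RunKernel(pack, num));
        IMPL_PACK_CASE(8, RunKernel(pack, num));
        IMPL_PACK_CASE(16, RunKernel(pack, num));
        default:
          // Beyond the largest static case, a device-side pointer table
          // (PointerToPointer in the diff) would be used instead.
          std::printf("fallback path for %d inputs\n", num);
      }
    }
    #undef IMPL_PACK_CASE

    int main() {
      std::vector<float> a(3), b(3), c(3);
      Dispatch({a.data(), b.data(), c.data()});  // 3 inputs -> the "4" case
      return 0;
    }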