Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
0de94cd9
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0de94cd9
编写于
1月 03, 2023
作者:
L
limingshu
提交者:
GitHub
1月 03, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
H2D data transfer optimization for concat kernel (#49040)
上级
f484a61e
变更
3
展开全部
隐藏空白更改
内联
并排
Showing
3 changed file
with
342 addition
and
279 deletion
+342
-279
paddle/phi/backends/gpu/gpu_launch_config.h
paddle/phi/backends/gpu/gpu_launch_config.h
+8
-3
paddle/phi/kernels/funcs/concat_and_split_functor.cu
paddle/phi/kernels/funcs/concat_and_split_functor.cu
+264
-226
paddle/phi/kernels/gpu/stack_kernel.cu
paddle/phi/kernels/gpu/stack_kernel.cu
+70
-50
未找到文件。
paddle/phi/backends/gpu/gpu_launch_config.h
浏览文件 @
0de94cd9
...
...
@@ -53,20 +53,25 @@ inline T DivUp(T a, T b) {
// https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
// for round integer value into next highest power of 2.
inline
int64_t
RoundTo
PowerOfTwo
(
int64_t
n
)
{
inline
int64_t
RoundTo
NextHighPowOfTwo
(
int64_t
n
,
int64_t
min_val
=
1
)
{
n
--
;
n
|=
(
n
>>
1
);
n
|=
(
n
>>
2
);
n
|=
(
n
>>
4
);
n
|=
(
n
>>
8
);
n
|=
(
n
>>
16
);
int64_t
min_val
=
32
;
return
std
::
max
(
min_val
,
(
n
+
1
));
}
inline
int64_t
RoundToPowerOfTwo
(
int64_t
n
)
{
constexpr
int64_t
min_val
=
32
;
int64_t
num
=
RoundToNextHighPowOfTwo
(
n
,
min_val
);
#ifdef __HIPCC__
int64_t
max_val
=
256
;
#else
int64_t
max_val
=
1024
;
#endif
return
std
::
min
(
max_val
,
std
::
max
(
min_val
,
(
n
+
1
))
);
return
std
::
min
(
max_val
,
num
);
}
#ifdef WITH_NV_JETSON
...
...
paddle/phi/kernels/funcs/concat_and_split_functor.cu
浏览文件 @
0de94cd9
此差异已折叠。
点击以展开。
paddle/phi/kernels/gpu/stack_kernel.cu
浏览文件 @
0de94cd9
...
...
@@ -25,7 +25,9 @@ namespace phi {
template
<
typename
IndexT
>
struct
DivmodWarpper
{
public:
void
SetDivden
(
IndexT
dividen
)
{
divmoder
=
phi
::
funcs
::
FastDivMod
(
dividen
);
}
void
SetDivisor
(
IndexT
divisor
)
{
divmoder
=
phi
::
funcs
::
FastDivMod
(
divisor
);
}
__device__
inline
phi
::
funcs
::
FastDivMod
::
DivModT
div_mod
(
IndexT
val
)
{
return
divmoder
.
Divmod
(
val
);
}
...
...
@@ -39,7 +41,7 @@ struct DivmodWarpper<int64_t> {
public:
using
DivModT
=
phi
::
AlignedVector
<
int64_t
,
2
>
;
void
SetDiv
den
(
int64_t
dividen
)
{
dividen_
=
dividen
;
}
void
SetDiv
isor
(
int64_t
divisor
)
{
dividen_
=
divisor
;
}
__device__
inline
DivModT
div_mod
(
int64_t
val
)
{
DivModT
data
;
data
[
0
]
=
val
/
dividen_
;
...
...
@@ -51,15 +53,14 @@ struct DivmodWarpper<int64_t> {
int64_t
dividen_
;
};
constexpr
int
kWarpperSize
=
64
;
template
<
typename
T
,
typename
IndexT
>
template
<
typename
T
,
typename
IndexT
,
int
Size
>
struct
PointerArray
:
public
DivmodWarpper
<
IndexT
>
{
public:
const
T
*
data
[
kWarpper
Size
];
const
T
*
data
[
Size
];
PointerArray
(
const
std
::
vector
<
const
DenseTensor
*>&
x
,
int
num
,
int64_t
dividen
)
{
this
->
SetDiv
den
(
dividen
);
IndexT
divisor
)
{
this
->
SetDiv
isor
(
divisor
);
for
(
auto
i
=
0
;
i
<
num
;
++
i
)
{
data
[
i
]
=
x
[
i
]
->
data
<
T
>
();
}
...
...
@@ -69,33 +70,33 @@ struct PointerArray : public DivmodWarpper<IndexT> {
template
<
typename
Context
,
typename
T
,
typename
IndexT
>
struct
PointerToPointer
:
public
DivmodWarpper
<
IndexT
>
{
public:
T
**
data
;
T
**
data
{
nullptr
}
;
PointerToPointer
(
const
Context
&
ctx
,
const
std
::
vector
<
const
DenseTensor
*>&
x
,
int
num
,
int64_t
dividen
)
{
this
->
SetDivden
(
dividen
);
auto
byte_len
=
num
*
sizeof
(
T
*
);
IndexT
num
,
IndexT
divisor
,
paddle
::
memory
::
AllocationPtr
*
dev_ins_ptr
)
{
this
->
SetDivisor
(
divisor
);
std
::
vector
<
const
T
*>
x_datas
(
num
);
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
x_datas
[
i
]
=
x
[
i
]
->
data
<
T
>
();
}
auto
tmp_x_data
=
paddle
::
memory
::
Alloc
(
*
dev_ins_ptr
=
paddle
::
memory
::
Alloc
(
ctx
.
GetPlace
(),
byte_len
,
num
*
sizeof
(
T
*
)
,
phi
::
Stream
(
reinterpret_cast
<
phi
::
StreamId
>
(
ctx
.
stream
())));
paddle
::
memory
::
Copy
(
ctx
.
GetPlace
(),
tmp_x_data
->
ptr
(),
(
*
dev_ins_ptr
)
->
ptr
(),
phi
::
CPUPlace
(),
reinterpret_cast
<
void
*>
(
x_datas
.
data
()),
x_datas
.
size
()
*
sizeof
(
T
*
),
num
*
sizeof
(
T
*
),
ctx
.
stream
());
data
=
reinterpret_cast
<
T
**>
(
tmp_x_data
->
ptr
());
data
=
reinterpret_cast
<
T
**>
(
(
*
dev_ins_ptr
)
->
ptr
());
}
};
template
<
typename
T
,
typename
IndexT
,
typename
W
ar
pT
>
__global__
void
StackCUDAKernel
(
W
ar
pT
input_warpper
,
template
<
typename
T
,
typename
IndexT
,
typename
W
ra
pT
>
__global__
void
StackCUDAKernel
(
W
ra
pT
input_warpper
,
IndexT
split_size
,
IndexT
rows
,
IndexT
cols
,
...
...
@@ -117,14 +118,56 @@ __global__ void StackCUDAKernel(WarpT input_warpper,
}
}
template
<
typename
T
,
typename
IndexT
,
typename
Context
>
void
LaunchStackCUDAKernelWithIndexType
(
const
Context
&
ctx
,
const
IndexT
x_col
,
const
IndexT
x_row
,
const
IndexT
out_col
,
const
phi
::
backends
::
gpu
::
GpuLaunchConfig
&
cfg
,
const
std
::
vector
<
const
DenseTensor
*>&
x
,
T
*
dst_data
)
{
int
num
=
static_cast
<
int
>
(
x
.
size
());
#define IMPL_STACK_CUDA_KERNEL_CASE(size_, ...) \
case size_: { \
PointerArray<T, IndexT, size_> ptr_array(x, num, x_col); \
__VA_ARGS__; \
} break;
#define IMPL_STACK_CUDA_KERNEL_HELPER(...) \
IMPL_STACK_CUDA_KERNEL_CASE(2, ##__VA_ARGS__); \
IMPL_STACK_CUDA_KERNEL_CASE(8, ##__VA_ARGS__); \
IMPL_STACK_CUDA_KERNEL_CASE(16, ##__VA_ARGS__); \
IMPL_STACK_CUDA_KERNEL_CASE(32, ##__VA_ARGS__); \
IMPL_STACK_CUDA_KERNEL_CASE(64, ##__VA_ARGS__); \
IMPL_STACK_CUDA_KERNEL_CASE(128, ##__VA_ARGS__);
switch
(
phi
::
backends
::
gpu
::
RoundToNextHighPowOfTwo
(
num
,
4
))
{
IMPL_STACK_CUDA_KERNEL_HELPER
(
StackCUDAKernel
<
T
,
IndexT
,
decltype
(
ptr_array
)
>
<<<
cfg
.
block_per_grid
,
cfg
.
thread_per_block
,
0
,
ctx
.
stream
()
>>>
(
ptr_array
,
x_col
,
x_row
,
out_col
,
dst_data
));
default:
{
paddle
::
memory
::
AllocationPtr
dev_ins_ptr
{
nullptr
};
PointerToPointer
<
Context
,
T
,
IndexT
>
ptr_array
(
ctx
,
x
,
num
,
x_col
,
&
dev_ins_ptr
);
StackCUDAKernel
<
T
,
IndexT
,
decltype
(
ptr_array
)
>
<<<
cfg
.
block_per_grid
,
cfg
.
thread_per_block
,
0
,
ctx
.
stream
()
>>>
(
ptr_array
,
x_col
,
x_row
,
out_col
,
dst_data
);
}
}
#undef IMPL_STACK_CUDA_KERNEL_HELPER
#undef IMPL_STACK_CUDA_KERNEL_CASE
}
template
<
typename
T
,
typename
Context
>
void
StackKernel
(
const
Context
&
dev_ctx
,
const
std
::
vector
<
const
DenseTensor
*>&
x
,
int
axis
,
DenseTensor
*
out
)
{
if
(
axis
<
0
)
axis
+=
(
x
[
0
]
->
dims
().
size
()
+
1
);
int
n
=
static_cast
<
int
>
(
x
.
size
());
T
*
y
_data
=
dev_ctx
.
template
Alloc
<
T
>(
out
);
int
n
um
=
static_cast
<
int
>
(
x
.
size
());
T
*
dst
_data
=
dev_ctx
.
template
Alloc
<
T
>(
out
);
// Split x dim from axis to matrix
int64_t
x_row
=
1
,
x_col
=
1
;
...
...
@@ -132,40 +175,17 @@ void StackKernel(const Context& dev_ctx,
x_row
*=
x
[
0
]
->
dims
()[
i
];
}
x_col
=
x
[
0
]
->
numel
()
/
x_row
;
int64_t
out_col
=
x_col
*
n
;
int64_t
out_col
=
x_col
*
n
um
;
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig2D
(
dev_ctx
,
out_col
,
x_row
);
#define IMPL_STACK_CUDA_KERNEL(index_t, input_warpper) \
StackCUDAKernel<T, index_t, decltype(input_warpper)> \
<<<config.block_per_grid, \
config.thread_per_block, \
0, \
dev_ctx.stream()>>>(input_warpper, \
static_cast<index_t>(x_col), \
static_cast<index_t>(x_row), \
static_cast<index_t>(out_col), \
y_data);
bool
use_int32
=
out
->
numel
()
<
std
::
numeric_limits
<
int32_t
>::
max
();
if
(
n
<=
kWarpperSize
)
{
if
(
use_int32
)
{
PointerArray
<
T
,
int32_t
>
ptr_array
(
x
,
n
,
x_col
);
IMPL_STACK_CUDA_KERNEL
(
int32_t
,
ptr_array
);
}
else
{
PointerArray
<
T
,
int64_t
>
ptr_array
(
x
,
n
,
x_col
);
IMPL_STACK_CUDA_KERNEL
(
int64_t
,
ptr_array
);
}
if
(
out
->
numel
()
<
std
::
numeric_limits
<
int32_t
>::
max
())
{
LaunchStackCUDAKernelWithIndexType
<
T
,
int32_t
,
Context
>
(
dev_ctx
,
x_col
,
x_row
,
out_col
,
config
,
x
,
dst_data
);
}
else
{
if
(
use_int32
)
{
PointerToPointer
<
Context
,
T
,
int32_t
>
ptr_array
(
dev_ctx
,
x
,
n
,
x_col
);
IMPL_STACK_CUDA_KERNEL
(
int32_t
,
ptr_array
);
}
else
{
PointerToPointer
<
Context
,
T
,
int64_t
>
ptr_array
(
dev_ctx
,
x
,
n
,
x_col
);
IMPL_STACK_CUDA_KERNEL
(
int64_t
,
ptr_array
);
}
LaunchStackCUDAKernelWithIndexType
<
T
,
int64_t
,
Context
>
(
dev_ctx
,
x_col
,
x_row
,
out_col
,
config
,
x
,
dst_data
);
}
#undef IMPL_STACK_CUDA_KERNEL
}
}
// namespace phi
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录