Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
c7855125
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c7855125
编写于
5月 10, 2022
作者:
S
shixingbo
提交者:
GitHub
5月 10, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
broadcast_add kp performance optimization (#42097)
上级
81078a88
变更
7
展开全部
隐藏空白更改
内联
并排
Showing
7 changed file
with
880 addition
and
43 deletion
+880
-43
paddle/phi/kernels/funcs/broadcast_function.h
paddle/phi/kernels/funcs/broadcast_function.h
+107
-31
paddle/phi/kernels/funcs/elementwise_base.h
paddle/phi/kernels/funcs/elementwise_base.h
+49
-7
paddle/phi/kernels/primitive/compute_primitives.h
paddle/phi/kernels/primitive/compute_primitives.h
+14
-0
paddle/phi/kernels/primitive/compute_primitives_xpu2.h
paddle/phi/kernels/primitive/compute_primitives_xpu2.h
+14
-0
paddle/phi/kernels/primitive/datamover_primitives.h
paddle/phi/kernels/primitive/datamover_primitives.h
+104
-0
paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
+591
-5
paddle/phi/kernels/primitive/kernel_primitives.h
paddle/phi/kernels/primitive/kernel_primitives.h
+1
-0
未找到文件。
paddle/phi/kernels/funcs/broadcast_function.h
浏览文件 @
c7855125
...
@@ -242,6 +242,27 @@ __device__ __forceinline__ void LoadData(
...
@@ -242,6 +242,27 @@ __device__ __forceinline__ void LoadData(
}
}
}
}
/**
 * @brief Fill `dst` with this thread's share of one input operand.
 *
 * Two read paths are available and `need_broadcast` selects between them:
 *   - non-zero: kps::ReadDataBc is called with the untouched `src` pointer,
 *     `block_offset`, the broadcast `config` and `numel`;
 *   - zero: kps::ReadData is called on `src + block_offset` with `num`.
 * `read_lens` is forwarded unchanged to whichever primitive is selected.
 *
 * @param dst            Destination buffer for the loaded values.
 * @param src            Base pointer of the input operand.
 * @param block_offset   Offset of the current block inside the data.
 * @param config         Broadcast configuration (only used when broadcasting).
 * @param numel          Total number of output elements.
 * @param num            Number of elements to be handled in this pass.
 * @param need_broadcast Non-zero when the operand has to be broadcast.
 * @param read_lens      Runtime read length handed to the kps primitive.
 */
template <typename T, int VecSize, int Rank, bool IsBoundary = false>
__device__ __forceinline__ void LoadData(
    T* dst,
    const _ptr_ T* src,
    uint32_t block_offset,
    const kps::details::BroadcastConfig<Rank>& config,
    int numel,
    int num,
    int need_broadcast,
    int read_lens) {
  // Plain (non-broadcast) operand: read straight from the block's start.
  if (!need_broadcast) {
    kps::ReadData<T, VecSize, 1, 1, IsBoundary>(
        dst, src + block_offset, num, read_lens);
    return;
  }

  // Broadcast operand: the primitive resolves source indices via `config`.
  kps::ReadDataBc<T, VecSize, 1, 1, Rank, IsBoundary>(
      dst, src, block_offset, config, numel, read_lens);
}
template
<
typename
InT
,
template
<
typename
InT
,
typename
OutT
,
typename
OutT
,
typename
Functor
,
typename
Functor
,
...
@@ -258,20 +279,22 @@ __device__ void VectorizedBroadcastKernelImpl(
...
@@ -258,20 +279,22 @@ __device__ void VectorizedBroadcastKernelImpl(
const
phi
::
Array
<
kps
::
details
::
BroadcastConfig
<
Rank
>
,
Arity
>
&
configs
,
const
phi
::
Array
<
kps
::
details
::
BroadcastConfig
<
Rank
>
,
Arity
>
&
configs
,
int
num
,
int
num
,
int
block_offset
,
int
block_offset
,
int
read_lens
,
Functor
func
)
{
Functor
func
)
{
InT
args
[
Arity
][
VecSize
];
__simd__
InT
args
[
Arity
][
VecSize
];
ConditionalT
<
OutT
,
NumOuts
>
result
[
VecSize
];
__simd__
ConditionalT
<
OutT
,
NumOuts
>
result
[
VecSize
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
Arity
;
i
++
)
{
for
(
int
i
=
0
;
i
<
Arity
;
i
++
)
{
kps
::
Init
<
InT
,
VecSize
>
(
args
[
i
],
static_cast
<
InT
>
(
1.0
f
));
kps
::
Init
<
InT
,
VecSize
>
(
args
[
i
],
static_cast
<
InT
>
(
1.0
f
)
,
read_lens
);
LoadData
<
InT
,
VecSize
,
Rank
,
IsBoundary
>
(
args
[
i
],
LoadData
<
InT
,
VecSize
,
Rank
,
IsBoundary
>
(
args
[
i
],
ins
[
i
],
ins
[
i
],
block_offset
,
block_offset
,
configs
[
i
],
configs
[
i
],
numel
,
numel
,
num
,
num
,
use_broadcast
[
i
]);
use_broadcast
[
i
],
read_lens
);
}
}
constexpr
bool
kCallElementwiseAny
=
constexpr
bool
kCallElementwiseAny
=
paddle
::
platform
::
FunctionTraits
<
Functor
>::
has_pointer_args
;
paddle
::
platform
::
FunctionTraits
<
Functor
>::
has_pointer_args
;
...
@@ -281,10 +304,10 @@ __device__ void VectorizedBroadcastKernelImpl(
...
@@ -281,10 +304,10 @@ __device__ void VectorizedBroadcastKernelImpl(
Functor
,
Functor
,
Arity
,
Arity
,
kCallElementwiseAny
>
()(
kCallElementwiseAny
>
()(
func
,
args
,
result
);
func
,
args
,
result
,
read_lens
);
phi
::
funcs
::
phi
::
funcs
::
ElementwiseWriteDataCaller
<
OutT
,
VecSize
,
IsBoundary
,
NumOuts
>
()(
ElementwiseWriteDataCallerBc
<
OutT
,
VecSize
,
IsBoundary
,
NumOuts
>
()(
outs
,
result
,
block_offset
,
num
);
outs
,
result
,
block_offset
,
num
,
read_lens
);
}
}
template
<
typename
InT
,
template
<
typename
InT
,
...
@@ -302,9 +325,10 @@ __global__ void VectorizedBroadcastKernel(
...
@@ -302,9 +325,10 @@ __global__ void VectorizedBroadcastKernel(
phi
::
Array
<
kps
::
details
::
BroadcastConfig
<
Rank
>
,
Arity
>
configs
,
phi
::
Array
<
kps
::
details
::
BroadcastConfig
<
Rank
>
,
Arity
>
configs
,
int
main_offset
,
int
main_offset
,
int
tail_tid
,
int
tail_tid
,
int
read_lens
,
Functor
func
)
{
Functor
func
)
{
int
block_offset
=
BLOCK_ID_X
*
BLOCK_NUM_X
*
VecSize
;
int
block_offset
=
BLOCK_ID_X
*
BLOCK_NUM_X
*
read_lens
;
int
stride
=
BLOCK_NUM_X
*
GRID_NUM_X
*
VecSize
;
int
stride
=
BLOCK_NUM_X
*
GRID_NUM_X
*
read_lens
;
#ifdef PADDLE_WITH_XPU_KP
#ifdef PADDLE_WITH_XPU_KP
for
(;
block_offset
<
main_offset
;
block_offset
+=
stride
)
{
for
(;
block_offset
<
main_offset
;
block_offset
+=
stride
)
{
...
@@ -320,8 +344,9 @@ __global__ void VectorizedBroadcastKernel(
...
@@ -320,8 +344,9 @@ __global__ void VectorizedBroadcastKernel(
use_broadcast
,
use_broadcast
,
numel
,
numel
,
configs
,
configs
,
BLOCK_NUM_X
*
VecSize
,
BLOCK_NUM_X
*
read_lens
,
block_offset
,
block_offset
,
read_lens
,
func
);
func
);
}
}
int
num
=
numel
-
block_offset
;
int
num
=
numel
-
block_offset
;
...
@@ -333,8 +358,15 @@ __global__ void VectorizedBroadcastKernel(
...
@@ -333,8 +358,15 @@ __global__ void VectorizedBroadcastKernel(
NumOuts
,
NumOuts
,
VecSize
,
VecSize
,
Rank
,
Rank
,
true
>
(
true
>
(
ins
,
ins
,
outs
,
use_broadcast
,
numel
,
configs
,
num
,
block_offset
,
func
);
outs
,
use_broadcast
,
numel
,
configs
,
num
,
block_offset
,
read_lens
,
func
);
}
}
#else
#else
if
(
block_offset
<
main_offset
)
{
if
(
block_offset
<
main_offset
)
{
...
@@ -352,6 +384,7 @@ __global__ void VectorizedBroadcastKernel(
...
@@ -352,6 +384,7 @@ __global__ void VectorizedBroadcastKernel(
configs
,
configs
,
BLOCK_NUM_X
*
VecSize
,
BLOCK_NUM_X
*
VecSize
,
block_offset
,
block_offset
,
read_lens
,
func
);
func
);
}
else
{
}
else
{
VectorizedBroadcastKernelImpl
<
InT
,
VectorizedBroadcastKernelImpl
<
InT
,
...
@@ -361,8 +394,15 @@ __global__ void VectorizedBroadcastKernel(
...
@@ -361,8 +394,15 @@ __global__ void VectorizedBroadcastKernel(
NumOuts
,
NumOuts
,
VecSize
,
VecSize
,
Rank
,
Rank
,
true
>
(
true
>
(
ins
,
ins
,
outs
,
use_broadcast
,
numel
,
configs
,
tail_tid
,
block_offset
,
func
);
outs
,
use_broadcast
,
numel
,
configs
,
tail_tid
,
block_offset
,
read_lens
,
func
);
}
}
#endif
#endif
}
}
...
@@ -392,6 +432,19 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
...
@@ -392,6 +432,19 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
for
(
int
i
=
0
;
i
<
Arity
;
i
++
)
{
for
(
int
i
=
0
;
i
<
Arity
;
i
++
)
{
use_broadcast
[
i
]
=
(
ins
[
i
]
->
numel
()
!=
numel
);
use_broadcast
[
i
]
=
(
ins
[
i
]
->
numel
()
!=
numel
);
ins_data
[
i
]
=
(
const
_ptr_
InT
*
)(
ins
[
i
]
->
data
<
InT
>
());
ins_data
[
i
]
=
(
const
_ptr_
InT
*
)(
ins
[
i
]
->
data
<
InT
>
());
#ifdef PADDLE_WITH_XPU_KP
if
(
i
==
0
)
{
configs
[
i
]
=
kps
::
details
::
BroadcastConfig
<
Rank
>
(
merge_dims
.
out_dims
,
merge_dims
.
in_dims
[
0
],
merge_dims
.
in_dims
[
1
],
merge_dims
.
dim_size
);
}
else
if
(
i
==
1
)
{
configs
[
i
]
=
kps
::
details
::
BroadcastConfig
<
Rank
>
(
merge_dims
.
out_dims
,
merge_dims
.
in_dims
[
1
],
merge_dims
.
in_dims
[
0
],
merge_dims
.
dim_size
);
}
#else
if
(
use_broadcast
[
i
])
{
if
(
use_broadcast
[
i
])
{
// get the broadcast config,
// get the broadcast config,
// if the data shape is [m, n], then you should set data_dim = {n, m}
// if the data shape is [m, n], then you should set data_dim = {n, m}
...
@@ -399,28 +452,50 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
...
@@ -399,28 +452,50 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
configs
[
i
]
=
kps
::
details
::
BroadcastConfig
<
Rank
>
(
configs
[
i
]
=
kps
::
details
::
BroadcastConfig
<
Rank
>
(
merge_dims
.
out_dims
,
merge_dims
.
in_dims
[
i
],
merge_dims
.
dim_size
);
merge_dims
.
out_dims
,
merge_dims
.
in_dims
[
i
],
merge_dims
.
dim_size
);
}
}
#endif
}
}
#ifdef PADDLE_WITH_XPU_KP
#ifdef PADDLE_WITH_XPU_KP
const
int
threads
=
64
;
const
int
threads
=
64
;
const
int
blocks
=
8
;
const
int
blocks
=
8
;
int
main_offset
=
(
numel
/
(
VecSize
*
threads
))
*
VecSize
*
threads
;
int
read_lens
=
configs
[
0
].
buf_len
;
int
tail_tid
=
numel
%
(
VecSize
*
threads
);
int
main_offset
=
(
numel
/
(
read_lens
*
threads
))
*
read_lens
*
threads
;
int
tail_tid
=
numel
%
(
read_lens
*
threads
);
auto
stream
=
ctx
.
x_context
()
->
xpu_stream
;
auto
stream
=
ctx
.
x_context
()
->
xpu_stream
;
VectorizedBroadcastKernel
<
InT
,
if
(
configs
[
0
].
cmp_type
!=
kps
::
details
::
OptType
::
CanNotOptimize
)
{
OutT
,
main_offset
=
numel
;
Functor
,
VectorizedBroadcastKernel
<
InT
,
Arity
,
OutT
,
NumOuts
,
Functor
,
VecSize
,
Arity
,
Rank
><<<
blocks
,
threads
,
stream
>>>
(
ins_data
,
NumOuts
,
outs_data
,
512
,
use_broadcast
,
Rank
><<<
blocks
,
threads
,
stream
>>>
(
ins_data
,
numel
,
outs_data
,
configs
,
use_broadcast
,
main_offset
,
numel
,
tail_tid
,
configs
,
func
);
main_offset
,
tail_tid
,
read_lens
,
func
);
}
else
{
VectorizedBroadcastKernel
<
InT
,
OutT
,
Functor
,
Arity
,
NumOuts
,
256
,
Rank
><<<
blocks
,
threads
,
stream
>>>
(
ins_data
,
outs_data
,
use_broadcast
,
numel
,
configs
,
main_offset
,
tail_tid
,
read_lens
,
func
);
}
#else
#else
const
int
threads
=
256
;
const
int
threads
=
256
;
int
blocks
=
((
numel
+
VecSize
-
1
)
/
VecSize
+
threads
-
1
)
/
threads
;
int
blocks
=
((
numel
+
VecSize
-
1
)
/
VecSize
+
threads
-
1
)
/
threads
;
...
@@ -440,6 +515,7 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
...
@@ -440,6 +515,7 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
configs
,
configs
,
main_offset
,
main_offset
,
tail_tid
,
tail_tid
,
VecSize
,
func
);
func
);
#endif
#endif
}
}
...
...
paddle/phi/kernels/funcs/elementwise_base.h
浏览文件 @
c7855125
...
@@ -577,14 +577,16 @@ template <typename InT,
...
@@ -577,14 +577,16 @@ template <typename InT,
struct
ElementwisePrimitiveCaller
{
struct
ElementwisePrimitiveCaller
{
__device__
inline
void
operator
()(
Functor
func
,
__device__
inline
void
operator
()(
Functor
func
,
InT
(
*
args
)[
VecSize
],
InT
(
*
args
)[
VecSize
],
OutT
*
result
);
OutT
*
result
,
int
read_lens
);
};
};
template
<
typename
InT
,
typename
OutT
,
int
VecSize
,
typename
Functor
,
int
Arity
>
template
<
typename
InT
,
typename
OutT
,
int
VecSize
,
typename
Functor
,
int
Arity
>
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
Arity
,
true
>
{
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
Arity
,
true
>
{
__device__
inline
void
operator
()(
Functor
func
,
__device__
inline
void
operator
()(
Functor
func
,
InT
(
*
args
)[
VecSize
],
InT
(
*
args
)[
VecSize
],
OutT
*
result
)
{
OutT
*
result
,
int
read_lens
)
{
kps
::
ElementwiseAny
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Arity
,
Functor
>
(
kps
::
ElementwiseAny
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Arity
,
Functor
>
(
result
,
args
,
func
);
result
,
args
,
func
);
}
}
...
@@ -594,7 +596,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
...
@@ -594,7 +596,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
0
,
false
>
{
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
0
,
false
>
{
__device__
inline
void
operator
()(
Functor
func
,
__device__
inline
void
operator
()(
Functor
func
,
InT
(
*
args
)[
VecSize
],
InT
(
*
args
)[
VecSize
],
OutT
*
result
)
{
OutT
*
result
,
int
read_lens
)
{
kps
::
ElementwiseConstant
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Functor
>
(
result
,
func
);
kps
::
ElementwiseConstant
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Functor
>
(
result
,
func
);
}
}
};
};
...
@@ -603,7 +606,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
...
@@ -603,7 +606,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
1
,
false
>
{
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
1
,
false
>
{
__device__
inline
void
operator
()(
Functor
func
,
__device__
inline
void
operator
()(
Functor
func
,
InT
(
*
args
)[
VecSize
],
InT
(
*
args
)[
VecSize
],
OutT
*
result
)
{
OutT
*
result
,
int
read_lens
)
{
kps
::
ElementwiseUnary
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Functor
>
(
kps
::
ElementwiseUnary
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Functor
>
(
result
,
args
[
0
],
func
);
result
,
args
[
0
],
func
);
}
}
...
@@ -613,9 +617,10 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
...
@@ -613,9 +617,10 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
2
,
false
>
{
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
2
,
false
>
{
__device__
inline
void
operator
()(
Functor
func
,
__device__
inline
void
operator
()(
Functor
func
,
InT
(
*
args
)[
VecSize
],
InT
(
*
args
)[
VecSize
],
OutT
*
result
)
{
OutT
*
result
,
int
read_lens
)
{
kps
::
ElementwiseBinary
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Functor
>
(
kps
::
ElementwiseBinary
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Functor
>
(
result
,
args
[
0
],
args
[
1
],
func
);
result
,
args
[
0
],
args
[
1
],
func
,
read_lens
);
}
}
};
};
...
@@ -623,7 +628,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
...
@@ -623,7 +628,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
3
,
false
>
{
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
3
,
false
>
{
__device__
inline
void
operator
()(
Functor
func
,
__device__
inline
void
operator
()(
Functor
func
,
InT
(
*
args
)[
VecSize
],
InT
(
*
args
)[
VecSize
],
OutT
*
result
)
{
OutT
*
result
,
int
read_lens
)
{
kps
::
ElementwiseTernary
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Functor
>
(
kps
::
ElementwiseTernary
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Functor
>
(
result
,
args
[
0
],
args
[
1
],
args
[
2
],
func
);
result
,
args
[
0
],
args
[
1
],
args
[
2
],
func
);
}
}
...
@@ -696,6 +702,42 @@ struct ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, 1> {
...
@@ -696,6 +702,42 @@ struct ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, 1> {
}
}
};
};
/**
 * @brief Writer for kernels with several outputs (read_lens-aware variant).
 *
 * `src` holds one ConditionalT<OutT, NumOuts> per element; the values are
 * first regrouped into one contiguous row per output and each row is then
 * handed to kps::WriteData at `outs[k] + block_offset`.
 *
 * NOTE(review): the staging rows have VecSize columns while the regrouping
 * loop runs up to `read_lens`; this assumes read_lens <= VecSize - confirm
 * against the launch code.
 */
template <typename OutT, int VecSize, bool IsBoundary, int NumOuts>
struct ElementwiseWriteDataCallerBc {
  __device__ __forceinline__ void operator()(
      phi::Array<_ptr_ OutT*, NumOuts> outs,
      ConditionalT<OutT, NumOuts> src[VecSize],
      int block_offset,
      int num,
      int read_lens) {
    // One staging row per output tensor.
    OutT staged[NumOuts][VecSize];

    // Regroup: staged[output][element] <- src[element][output].
#pragma unroll
    for (int out_idx = 0; out_idx < NumOuts; ++out_idx) {
#pragma unroll
      for (int elem = 0; elem < read_lens; ++elem) {
        staged[out_idx][elem] = (src[elem])[out_idx];
      }
    }

    // Emit every staged row to its destination tensor.
#pragma unroll
    for (int out_idx = 0; out_idx < NumOuts; ++out_idx) {
      kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
          outs[out_idx] + block_offset, staged[out_idx], num, read_lens);
    }
  }
};
/**
 * @brief Single-output specialization of ElementwiseWriteDataCallerBc.
 *
 * With exactly one output no regrouping is needed: `src` is passed directly
 * to kps::WriteData, targeting `outs[0] + block_offset`.
 */
template <typename OutT, int VecSize, bool IsBoundary>
struct ElementwiseWriteDataCallerBc<OutT, VecSize, IsBoundary, 1> {
  __device__ __forceinline__ void operator()(phi::Array<_ptr_ OutT*, 1> outs,
                                             OutT src[VecSize],
                                             int block_offset,
                                             int num,
                                             int read_lens) {
    // Destination of this block inside the only output tensor.
    _ptr_ OutT* block_dst = outs[0] + block_offset;
    kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
        block_dst, src, num, read_lens);
  }
};
template
<
typename
OutT
,
template
<
typename
OutT
,
typename
Functor
,
typename
Functor
,
int
Arity
,
int
Arity
,
...
...
paddle/phi/kernels/primitive/compute_primitives.h
浏览文件 @
c7855125
...
@@ -271,6 +271,20 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out,
...
@@ -271,6 +271,20 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out,
}
}
}
}
/**
 * @brief Binary element-wise computation, read_lens-aware signature.
 *
 * Applies `compute` to matching elements of `in1` and `in2` and stores the
 * result, converted to OutT, in `out`. This implementation always processes
 * NX * NY elements; `read_lens` is part of the signature but is not read
 * here.
 *
 * @param out       Destination buffer (at least NX * NY elements).
 * @param in1       First input buffer.
 * @param in2       Second input buffer.
 * @param compute   Callable invoked as compute(in1[k], in2[k]).
 * @param read_lens Unused by this implementation.
 */
template <typename InT,
          typename OutT,
          int NX,
          int NY,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseBinary(OutT* out,
                                                  const InT* in1,
                                                  const InT* in2,
                                                  OpFunc compute,
                                                  int read_lens) {
  constexpr int kElemCount = NX * NY;
#pragma unroll
  for (int k = 0; k < kElemCount; ++k) {
    const InT& lhs = in1[k];
    const InT& rhs = in2[k];
    out[k] = static_cast<OutT>(compute(lhs, rhs));
  }
}
/**
/**
* @brief Ternary calculation according to OpFunc. Shape of input and output
* @brief Ternary calculation according to OpFunc. Shape of input and output
* are the same.
* are the same.
...
...
paddle/phi/kernels/primitive/compute_primitives_xpu2.h
浏览文件 @
c7855125
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include "xpu/kernel/cluster_header.h"
#include "xpu/kernel/cluster_header.h"
#include "xpu/kernel/debug.h"
#include "xpu/kernel/debug.h"
#include "xpu/kernel/math.h"
#include "xpu/kernel/math.h"
#include "xpu/kernel/simd_header.h"
namespace
phi
{
namespace
phi
{
namespace
kps
{
namespace
kps
{
...
@@ -158,6 +159,19 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out,
...
@@ -158,6 +159,19 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out,
}
}
}
}
/**
 * @brief Binary element-wise computation over a runtime-sized range.
 *
 * Applies `compute` to matching elements of `in1` and `in2` and stores the
 * result, converted to OutT, in `out`. Exactly `read_lens` elements are
 * processed; the NX / NY template parameters do not bound the loop.
 *
 * @param out       Destination buffer (at least read_lens elements).
 * @param in1       First input buffer.
 * @param in2       Second input buffer.
 * @param compute   Callable invoked as compute(in1[k], in2[k]).
 * @param read_lens Number of elements to process.
 */
template <typename InT,
          typename OutT,
          int NX,
          int NY,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseBinary(OutT* out,
                                                  const InT* in1,
                                                  const InT* in2,
                                                  OpFunc compute,
                                                  int read_lens) {
  int k = 0;
  while (k < read_lens) {
    const InT& lhs = in1[k];
    const InT& rhs = in2[k];
    out[k] = static_cast<OutT>(compute(lhs, rhs));
    ++k;
  }
}
/**
/**
* @brief Ternary calculation according to OpFunc. Shape of input and output
* @brief Ternary calculation according to OpFunc. Shape of input and output
* are the same.
* are the same.
...
...
paddle/phi/kernels/primitive/datamover_primitives.h
浏览文件 @
c7855125
...
@@ -246,6 +246,14 @@ __device__ __forceinline__ void Init(T* dst, T init_data) {
...
@@ -246,6 +246,14 @@ __device__ __forceinline__ void Init(T* dst, T init_data) {
}
}
}
}
/**
 * @brief Fill a buffer with a constant, read_lens-aware signature.
 *
 * Stores `init_data` into dst[0] .. dst[NX - 1]. `read_lens` is part of the
 * signature but is not read by this implementation; the element count comes
 * from the NX template parameter.
 *
 * @param dst       Buffer to initialize (at least NX elements).
 * @param init_data Value written to every element.
 * @param read_lens Unused by this implementation.
 */
template <typename T, int NX>
__device__ __forceinline__ void Init(T* dst, T init_data, int read_lens) {
  T* cursor = dst;
#pragma unroll
  for (int remaining = NX; remaining > 0; --remaining) {
    *cursor = init_data;
    ++cursor;
  }
}
/**
/**
* The difference from the above function is that
* The difference from the above function is that
* it supports different data types of inputs.
* it supports different data types of inputs.
...
@@ -311,6 +319,38 @@ __device__ __forceinline__ void ReadData(T* dst,
...
@@ -311,6 +319,38 @@ __device__ __forceinline__ void ReadData(T* dst,
}
}
}
}
/**
 * @brief Read this thread's NX elements from `src` into `dst`
 *        (read_lens-aware signature).
 *
 * Boundary path (IsBoundary == true): elements are copied one by one and an
 * element is only written when its position `threadIdx.x * NX + idx` is
 * below `num`; other `dst` slots are left as they were.
 *
 * Main path (IsBoundary == false): `src` is reinterpreted as an array of
 * details::VectorType<T, kVectorSize> and each thread loads
 * kVectorsPerThread packed vectors, which are then unpacked into `dst`.
 *
 * `read_lens` is part of the signature but is not read by this
 * implementation.
 *
 * @param dst       Destination buffer (at least NX elements).
 * @param src       Source pointer for the current block.
 * @param num       Number of valid elements, used by the boundary path.
 * @param read_lens Unused by this implementation.
 */
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void ReadData(T* dst,
                                         const T* __restrict__ src,
                                         int num,
                                         int read_lens) {
  if (IsBoundary) {  // blockDim.x * NX > num
    int thread_offset = threadIdx.x * NX;
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (idx + thread_offset < num) {
        dst[idx] = src[thread_offset + idx];
      }
    }
  } else {  // blockDim.x * NX < num
    // Widest packing (4, 2 or 1 elements) that evenly divides NX.
    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
    constexpr int kVectorsPerThread = NX / kVectorSize;
    int thread_offset = threadIdx.x * kVectorsPerThread;

    using VecType = details::VectorType<T, kVectorSize>;
    const VecType* vec_input = reinterpret_cast<const VecType*>(src);
    VecType vec_temp[kVectorsPerThread];

    // Load every packed vector first, so that the unpacking below never
    // touches a vec_temp element that has not been written yet.
#pragma unroll
    for (int i = 0; i < kVectorsPerThread; ++i) {
      vec_temp[i] = vec_input[thread_offset + i];
    }

    // Unpack once: all NX scalars are now available in vec_temp.
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      dst[idx] = *(reinterpret_cast<T*>(vec_temp) + idx);
    }
  }
}
/**
/**
* @brief Read 1D data from global memory to register. The difference
* @brief Read 1D data from global memory to register. The difference
* from the above function is that it supports different data types of inputs.
* from the above function is that it supports different data types of inputs.
...
@@ -576,6 +616,36 @@ __device__ __forceinline__ void WriteData(T* dst,
...
@@ -576,6 +616,36 @@ __device__ __forceinline__ void WriteData(T* dst,
}
}
}
}
/**
 * @brief Write this thread's NX elements from `src` to `dst`
 *        (read_lens-aware signature).
 *
 * Main path (IsBoundary == false): both buffers are viewed as arrays of
 * details::VectorType<T, kVectorSize> and kVectorsPerThread packed vectors
 * are stored starting at vector index `threadIdx.x * kVectorsPerThread`.
 *
 * Boundary path (IsBoundary == true): elements are stored one by one and a
 * store only happens when `threadIdx.x * NX + idx` is below `num`.
 *
 * `read_lens` is part of the signature but is not read by this
 * implementation.
 *
 * @param dst       Destination pointer for the current block.
 * @param src       Source buffer holding this thread's NX elements.
 * @param num       Number of valid elements, used by the boundary path.
 * @param read_lens Unused by this implementation.
 */
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void WriteData(T* dst,
                                          T* __restrict__ src,
                                          int num,
                                          int read_lens) {
  if (!IsBoundary) {
    // Widest packing (4, 2 or 1 elements) that evenly divides NX.
    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
    constexpr int kVectorsPerThread = NX / kVectorSize;
    using VecType = details::VectorType<T, kVectorSize>;

    const int first_vec = threadIdx.x * kVectorsPerThread;
    VecType* packed_dst = reinterpret_cast<VecType*>(dst);
    VecType* packed_src = reinterpret_cast<VecType*>(src);
#pragma unroll
    for (int v = 0; v < kVectorsPerThread; ++v) {
      VecType packed = *(packed_src + v);
      packed_dst[first_vec + v] = packed;
    }
    return;
  }

  // Boundary block: guard every scalar store against `num`.
  const int first_elem = threadIdx.x * NX;
#pragma unroll
  for (int k = 0; k < NX; ++k) {
    const int pos = first_elem + k;
    if (pos < num) {
      dst[pos] = src[k];
    }
  }
}
/**
/**
* @brief Write 2D data from register to global memory according to Tx type, and
* @brief Write 2D data from register to global memory according to Tx type, and
* store it as Ty type.
* store it as Ty type.
...
@@ -749,6 +819,40 @@ __device__ __forceinline__ void ReadDataBc(
...
@@ -749,6 +819,40 @@ __device__ __forceinline__ void ReadDataBc(
}
}
}
}
/**
 * @brief Broadcast read of this thread's NX elements
 *        (read_lens-aware signature).
 *
 * For every output position `block_offset + threadIdx.x * NX + nx` the
 * source index is accumulated over `Rank` dimensions: each step calls
 * config.divmoders[d].Divmod on the running index, continues with val[0]
 * and adds val[1] * config.strides[d] to the source index (presumably
 * val[0] / val[1] are quotient / remainder - confirm in FastDivMod).
 *
 * When IsBoundary is true the loop stops at the first output position that
 * is not below `total_num_output`; remaining `dst` slots are left as is.
 *
 * `read_lens` is part of the signature but is not read by this
 * implementation.
 *
 * @param dst              Destination buffer (at least NX elements).
 * @param src              Base pointer of the broadcast input.
 * @param block_offset     Offset of the current block in the output.
 * @param config           Broadcast configuration (divmoders and strides).
 * @param total_num_output Total number of output elements.
 * @param read_lens        Unused by this implementation.
 */
template <typename T,
          int NX,
          int NY,
          int BlockSize,
          int Rank,
          bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
    T* dst,
    const T* __restrict__ src,
    uint32_t block_offset,
    details::BroadcastConfig<Rank> config,
    int total_num_output,
    int read_lens) {
  const uint32_t first_output = block_offset + threadIdx.x * NX;

#pragma unroll
  for (uint32_t nx = 0; nx < NX; ++nx) {
    uint32_t running = first_output + nx;
    if (IsBoundary && running >= total_num_output) {
      break;
    }

    // Fold the output index into a source index, one dimension at a time.
    uint32_t src_pos = 0;
#pragma unroll
    for (int d = 0; d < Rank; ++d) {
      auto dm = config.divmoders[d].Divmod(running);
      running = dm.val[0];
      src_pos += dm.val[1] * config.strides[d];
    }
    dst[nx] = src[src_pos];
  }
}
/**
/**
* @brief Initialize register with data index.
* @brief Initialize register with data index.
*
*
...
...
paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
浏览文件 @
c7855125
此差异已折叠。
点击以展开。
paddle/phi/kernels/primitive/kernel_primitives.h
浏览文件 @
c7855125
...
@@ -46,6 +46,7 @@
...
@@ -46,6 +46,7 @@
#define KPStream gpuStream_t
#define KPStream gpuStream_t
#define KPDevice phi::GPUContext
#define KPDevice phi::GPUContext
#define _ptr_
#define _ptr_
#define __simd__
#define THREAD_ID_X threadIdx.x
#define THREAD_ID_X threadIdx.x
#define THREAD_ID_Y threadIdx.y
#define THREAD_ID_Y threadIdx.y
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录