Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
c7855125
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c7855125
编写于
5月 10, 2022
作者:
S
shixingbo
提交者:
GitHub
5月 10, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
broadcast_add kp performance optimization (#42097)
上级
81078a88
变更
7
展开全部
显示空白变更内容
内联
并排
Showing
7 changed file
with
880 addition
and
43 deletion
+880
-43
paddle/phi/kernels/funcs/broadcast_function.h
paddle/phi/kernels/funcs/broadcast_function.h
+107
-31
paddle/phi/kernels/funcs/elementwise_base.h
paddle/phi/kernels/funcs/elementwise_base.h
+49
-7
paddle/phi/kernels/primitive/compute_primitives.h
paddle/phi/kernels/primitive/compute_primitives.h
+14
-0
paddle/phi/kernels/primitive/compute_primitives_xpu2.h
paddle/phi/kernels/primitive/compute_primitives_xpu2.h
+14
-0
paddle/phi/kernels/primitive/datamover_primitives.h
paddle/phi/kernels/primitive/datamover_primitives.h
+104
-0
paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
+591
-5
paddle/phi/kernels/primitive/kernel_primitives.h
paddle/phi/kernels/primitive/kernel_primitives.h
+1
-0
未找到文件。
paddle/phi/kernels/funcs/broadcast_function.h
浏览文件 @
c7855125
...
@@ -242,6 +242,27 @@ __device__ __forceinline__ void LoadData(
...
@@ -242,6 +242,27 @@ __device__ __forceinline__ void LoadData(
}
}
}
}
/**
 * @brief Load one thread's slice of an input into `dst`, choosing between the
 * broadcast read and the plain contiguous read.
 *
 * @param dst            destination buffer for this thread's elements.
 * @param src            base pointer of the input.
 * @param block_offset   offset of the current block; passed through to
 *                       ReadDataBc, or added to `src` for the plain read.
 * @param config         broadcast configuration, only used when broadcasting.
 * @param numel          whole number of output elements.
 * @param num            how many elements are handled in this call.
 * @param need_broadcast non-zero selects the broadcast read path.
 * @param read_lens      per-thread read length forwarded to the kps reader.
 */
template <typename T, int VecSize, int Rank, bool IsBoundary = false>
__device__ __forceinline__ void LoadData(
    T *dst,
    const _ptr_ T *src,
    uint32_t block_offset,
    const kps::details::BroadcastConfig<Rank> &config,
    int numel,
    int num,
    int need_broadcast,
    int read_lens) {
  // No broadcast needed: read straight from the block's offset.
  if (!need_broadcast) {
    kps::ReadData<T, VecSize, 1, 1, IsBoundary>(
        dst, src + block_offset, num, read_lens);
    return;
  }
  // Broadcast path: the reader resolves source indices through `config`.
  kps::ReadDataBc<T, VecSize, 1, 1, Rank, IsBoundary>(
      dst, src, block_offset, config, numel, read_lens);
}
template
<
typename
InT
,
template
<
typename
InT
,
typename
OutT
,
typename
OutT
,
typename
Functor
,
typename
Functor
,
...
@@ -258,20 +279,22 @@ __device__ void VectorizedBroadcastKernelImpl(
...
@@ -258,20 +279,22 @@ __device__ void VectorizedBroadcastKernelImpl(
const
phi
::
Array
<
kps
::
details
::
BroadcastConfig
<
Rank
>
,
Arity
>
&
configs
,
const
phi
::
Array
<
kps
::
details
::
BroadcastConfig
<
Rank
>
,
Arity
>
&
configs
,
int
num
,
int
num
,
int
block_offset
,
int
block_offset
,
int
read_lens
,
Functor
func
)
{
Functor
func
)
{
InT
args
[
Arity
][
VecSize
];
__simd__
InT
args
[
Arity
][
VecSize
];
ConditionalT
<
OutT
,
NumOuts
>
result
[
VecSize
];
__simd__
ConditionalT
<
OutT
,
NumOuts
>
result
[
VecSize
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
Arity
;
i
++
)
{
for
(
int
i
=
0
;
i
<
Arity
;
i
++
)
{
kps
::
Init
<
InT
,
VecSize
>
(
args
[
i
],
static_cast
<
InT
>
(
1.0
f
));
kps
::
Init
<
InT
,
VecSize
>
(
args
[
i
],
static_cast
<
InT
>
(
1.0
f
)
,
read_lens
);
LoadData
<
InT
,
VecSize
,
Rank
,
IsBoundary
>
(
args
[
i
],
LoadData
<
InT
,
VecSize
,
Rank
,
IsBoundary
>
(
args
[
i
],
ins
[
i
],
ins
[
i
],
block_offset
,
block_offset
,
configs
[
i
],
configs
[
i
],
numel
,
numel
,
num
,
num
,
use_broadcast
[
i
]);
use_broadcast
[
i
],
read_lens
);
}
}
constexpr
bool
kCallElementwiseAny
=
constexpr
bool
kCallElementwiseAny
=
paddle
::
platform
::
FunctionTraits
<
Functor
>::
has_pointer_args
;
paddle
::
platform
::
FunctionTraits
<
Functor
>::
has_pointer_args
;
...
@@ -281,10 +304,10 @@ __device__ void VectorizedBroadcastKernelImpl(
...
@@ -281,10 +304,10 @@ __device__ void VectorizedBroadcastKernelImpl(
Functor
,
Functor
,
Arity
,
Arity
,
kCallElementwiseAny
>
()(
kCallElementwiseAny
>
()(
func
,
args
,
result
);
func
,
args
,
result
,
read_lens
);
phi
::
funcs
::
phi
::
funcs
::
ElementwiseWriteDataCaller
<
OutT
,
VecSize
,
IsBoundary
,
NumOuts
>
()(
ElementwiseWriteDataCallerBc
<
OutT
,
VecSize
,
IsBoundary
,
NumOuts
>
()(
outs
,
result
,
block_offset
,
num
);
outs
,
result
,
block_offset
,
num
,
read_lens
);
}
}
template
<
typename
InT
,
template
<
typename
InT
,
...
@@ -302,9 +325,10 @@ __global__ void VectorizedBroadcastKernel(
...
@@ -302,9 +325,10 @@ __global__ void VectorizedBroadcastKernel(
phi
::
Array
<
kps
::
details
::
BroadcastConfig
<
Rank
>
,
Arity
>
configs
,
phi
::
Array
<
kps
::
details
::
BroadcastConfig
<
Rank
>
,
Arity
>
configs
,
int
main_offset
,
int
main_offset
,
int
tail_tid
,
int
tail_tid
,
int
read_lens
,
Functor
func
)
{
Functor
func
)
{
int
block_offset
=
BLOCK_ID_X
*
BLOCK_NUM_X
*
VecSize
;
int
block_offset
=
BLOCK_ID_X
*
BLOCK_NUM_X
*
read_lens
;
int
stride
=
BLOCK_NUM_X
*
GRID_NUM_X
*
VecSize
;
int
stride
=
BLOCK_NUM_X
*
GRID_NUM_X
*
read_lens
;
#ifdef PADDLE_WITH_XPU_KP
#ifdef PADDLE_WITH_XPU_KP
for
(;
block_offset
<
main_offset
;
block_offset
+=
stride
)
{
for
(;
block_offset
<
main_offset
;
block_offset
+=
stride
)
{
...
@@ -320,8 +344,9 @@ __global__ void VectorizedBroadcastKernel(
...
@@ -320,8 +344,9 @@ __global__ void VectorizedBroadcastKernel(
use_broadcast
,
use_broadcast
,
numel
,
numel
,
configs
,
configs
,
BLOCK_NUM_X
*
VecSize
,
BLOCK_NUM_X
*
read_lens
,
block_offset
,
block_offset
,
read_lens
,
func
);
func
);
}
}
int
num
=
numel
-
block_offset
;
int
num
=
numel
-
block_offset
;
...
@@ -333,8 +358,15 @@ __global__ void VectorizedBroadcastKernel(
...
@@ -333,8 +358,15 @@ __global__ void VectorizedBroadcastKernel(
NumOuts
,
NumOuts
,
VecSize
,
VecSize
,
Rank
,
Rank
,
true
>
(
true
>
(
ins
,
ins
,
outs
,
use_broadcast
,
numel
,
configs
,
num
,
block_offset
,
func
);
outs
,
use_broadcast
,
numel
,
configs
,
num
,
block_offset
,
read_lens
,
func
);
}
}
#else
#else
if
(
block_offset
<
main_offset
)
{
if
(
block_offset
<
main_offset
)
{
...
@@ -352,6 +384,7 @@ __global__ void VectorizedBroadcastKernel(
...
@@ -352,6 +384,7 @@ __global__ void VectorizedBroadcastKernel(
configs
,
configs
,
BLOCK_NUM_X
*
VecSize
,
BLOCK_NUM_X
*
VecSize
,
block_offset
,
block_offset
,
read_lens
,
func
);
func
);
}
else
{
}
else
{
VectorizedBroadcastKernelImpl
<
InT
,
VectorizedBroadcastKernelImpl
<
InT
,
...
@@ -361,8 +394,15 @@ __global__ void VectorizedBroadcastKernel(
...
@@ -361,8 +394,15 @@ __global__ void VectorizedBroadcastKernel(
NumOuts
,
NumOuts
,
VecSize
,
VecSize
,
Rank
,
Rank
,
true
>
(
true
>
(
ins
,
ins
,
outs
,
use_broadcast
,
numel
,
configs
,
tail_tid
,
block_offset
,
func
);
outs
,
use_broadcast
,
numel
,
configs
,
tail_tid
,
block_offset
,
read_lens
,
func
);
}
}
#endif
#endif
}
}
...
@@ -392,6 +432,19 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
...
@@ -392,6 +432,19 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
for
(
int
i
=
0
;
i
<
Arity
;
i
++
)
{
for
(
int
i
=
0
;
i
<
Arity
;
i
++
)
{
use_broadcast
[
i
]
=
(
ins
[
i
]
->
numel
()
!=
numel
);
use_broadcast
[
i
]
=
(
ins
[
i
]
->
numel
()
!=
numel
);
ins_data
[
i
]
=
(
const
_ptr_
InT
*
)(
ins
[
i
]
->
data
<
InT
>
());
ins_data
[
i
]
=
(
const
_ptr_
InT
*
)(
ins
[
i
]
->
data
<
InT
>
());
#ifdef PADDLE_WITH_XPU_KP
if
(
i
==
0
)
{
configs
[
i
]
=
kps
::
details
::
BroadcastConfig
<
Rank
>
(
merge_dims
.
out_dims
,
merge_dims
.
in_dims
[
0
],
merge_dims
.
in_dims
[
1
],
merge_dims
.
dim_size
);
}
else
if
(
i
==
1
)
{
configs
[
i
]
=
kps
::
details
::
BroadcastConfig
<
Rank
>
(
merge_dims
.
out_dims
,
merge_dims
.
in_dims
[
1
],
merge_dims
.
in_dims
[
0
],
merge_dims
.
dim_size
);
}
#else
if
(
use_broadcast
[
i
])
{
if
(
use_broadcast
[
i
])
{
// get the broadcast config,
// get the broadcast config,
// if data shape is[m, n], then you should set data_dim = {n, m}
// if data shape is[m, n], then you should set data_dim = {n, m}
...
@@ -399,20 +452,24 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
...
@@ -399,20 +452,24 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
configs
[
i
]
=
kps
::
details
::
BroadcastConfig
<
Rank
>
(
configs
[
i
]
=
kps
::
details
::
BroadcastConfig
<
Rank
>
(
merge_dims
.
out_dims
,
merge_dims
.
in_dims
[
i
],
merge_dims
.
dim_size
);
merge_dims
.
out_dims
,
merge_dims
.
in_dims
[
i
],
merge_dims
.
dim_size
);
}
}
#endif
}
}
#ifdef PADDLE_WITH_XPU_KP
#ifdef PADDLE_WITH_XPU_KP
const
int
threads
=
64
;
const
int
threads
=
64
;
const
int
blocks
=
8
;
const
int
blocks
=
8
;
int
main_offset
=
(
numel
/
(
VecSize
*
threads
))
*
VecSize
*
threads
;
int
read_lens
=
configs
[
0
].
buf_len
;
int
tail_tid
=
numel
%
(
VecSize
*
threads
);
int
main_offset
=
(
numel
/
(
read_lens
*
threads
))
*
read_lens
*
threads
;
int
tail_tid
=
numel
%
(
read_lens
*
threads
);
auto
stream
=
ctx
.
x_context
()
->
xpu_stream
;
auto
stream
=
ctx
.
x_context
()
->
xpu_stream
;
if
(
configs
[
0
].
cmp_type
!=
kps
::
details
::
OptType
::
CanNotOptimize
)
{
main_offset
=
numel
;
VectorizedBroadcastKernel
<
InT
,
VectorizedBroadcastKernel
<
InT
,
OutT
,
OutT
,
Functor
,
Functor
,
Arity
,
Arity
,
NumOuts
,
NumOuts
,
VecSize
,
512
,
Rank
><<<
blocks
,
threads
,
stream
>>>
(
ins_data
,
Rank
><<<
blocks
,
threads
,
stream
>>>
(
ins_data
,
outs_data
,
outs_data
,
use_broadcast
,
use_broadcast
,
...
@@ -420,7 +477,25 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
...
@@ -420,7 +477,25 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
configs
,
configs
,
main_offset
,
main_offset
,
tail_tid
,
tail_tid
,
read_lens
,
func
);
func
);
}
else
{
VectorizedBroadcastKernel
<
InT
,
OutT
,
Functor
,
Arity
,
NumOuts
,
256
,
Rank
><<<
blocks
,
threads
,
stream
>>>
(
ins_data
,
outs_data
,
use_broadcast
,
numel
,
configs
,
main_offset
,
tail_tid
,
read_lens
,
func
);
}
#else
#else
const
int
threads
=
256
;
const
int
threads
=
256
;
int
blocks
=
((
numel
+
VecSize
-
1
)
/
VecSize
+
threads
-
1
)
/
threads
;
int
blocks
=
((
numel
+
VecSize
-
1
)
/
VecSize
+
threads
-
1
)
/
threads
;
...
@@ -440,6 +515,7 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
...
@@ -440,6 +515,7 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
configs
,
configs
,
main_offset
,
main_offset
,
tail_tid
,
tail_tid
,
VecSize
,
func
);
func
);
#endif
#endif
}
}
...
...
paddle/phi/kernels/funcs/elementwise_base.h
浏览文件 @
c7855125
...
@@ -577,14 +577,16 @@ template <typename InT,
...
@@ -577,14 +577,16 @@ template <typename InT,
struct
ElementwisePrimitiveCaller
{
struct
ElementwisePrimitiveCaller
{
__device__
inline
void
operator
()(
Functor
func
,
__device__
inline
void
operator
()(
Functor
func
,
InT
(
*
args
)[
VecSize
],
InT
(
*
args
)[
VecSize
],
OutT
*
result
);
OutT
*
result
,
int
read_lens
);
};
};
template
<
typename
InT
,
typename
OutT
,
int
VecSize
,
typename
Functor
,
int
Arity
>
template
<
typename
InT
,
typename
OutT
,
int
VecSize
,
typename
Functor
,
int
Arity
>
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
Arity
,
true
>
{
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
Arity
,
true
>
{
__device__
inline
void
operator
()(
Functor
func
,
__device__
inline
void
operator
()(
Functor
func
,
InT
(
*
args
)[
VecSize
],
InT
(
*
args
)[
VecSize
],
OutT
*
result
)
{
OutT
*
result
,
int
read_lens
)
{
kps
::
ElementwiseAny
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Arity
,
Functor
>
(
kps
::
ElementwiseAny
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Arity
,
Functor
>
(
result
,
args
,
func
);
result
,
args
,
func
);
}
}
...
@@ -594,7 +596,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
...
@@ -594,7 +596,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
0
,
false
>
{
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
0
,
false
>
{
__device__
inline
void
operator
()(
Functor
func
,
__device__
inline
void
operator
()(
Functor
func
,
InT
(
*
args
)[
VecSize
],
InT
(
*
args
)[
VecSize
],
OutT
*
result
)
{
OutT
*
result
,
int
read_lens
)
{
kps
::
ElementwiseConstant
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Functor
>
(
result
,
func
);
kps
::
ElementwiseConstant
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Functor
>
(
result
,
func
);
}
}
};
};
...
@@ -603,7 +606,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
...
@@ -603,7 +606,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
1
,
false
>
{
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
1
,
false
>
{
__device__
inline
void
operator
()(
Functor
func
,
__device__
inline
void
operator
()(
Functor
func
,
InT
(
*
args
)[
VecSize
],
InT
(
*
args
)[
VecSize
],
OutT
*
result
)
{
OutT
*
result
,
int
read_lens
)
{
kps
::
ElementwiseUnary
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Functor
>
(
kps
::
ElementwiseUnary
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Functor
>
(
result
,
args
[
0
],
func
);
result
,
args
[
0
],
func
);
}
}
...
@@ -613,9 +617,10 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
...
@@ -613,9 +617,10 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
2
,
false
>
{
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
2
,
false
>
{
__device__
inline
void
operator
()(
Functor
func
,
__device__
inline
void
operator
()(
Functor
func
,
InT
(
*
args
)[
VecSize
],
InT
(
*
args
)[
VecSize
],
OutT
*
result
)
{
OutT
*
result
,
int
read_lens
)
{
kps
::
ElementwiseBinary
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Functor
>
(
kps
::
ElementwiseBinary
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Functor
>
(
result
,
args
[
0
],
args
[
1
],
func
);
result
,
args
[
0
],
args
[
1
],
func
,
read_lens
);
}
}
};
};
...
@@ -623,7 +628,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
...
@@ -623,7 +628,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
3
,
false
>
{
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
3
,
false
>
{
__device__
inline
void
operator
()(
Functor
func
,
__device__
inline
void
operator
()(
Functor
func
,
InT
(
*
args
)[
VecSize
],
InT
(
*
args
)[
VecSize
],
OutT
*
result
)
{
OutT
*
result
,
int
read_lens
)
{
kps
::
ElementwiseTernary
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Functor
>
(
kps
::
ElementwiseTernary
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Functor
>
(
result
,
args
[
0
],
args
[
1
],
args
[
2
],
func
);
result
,
args
[
0
],
args
[
1
],
args
[
2
],
func
);
}
}
...
@@ -696,6 +702,42 @@ struct ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, 1> {
...
@@ -696,6 +702,42 @@ struct ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, 1> {
}
}
};
};
/**
 * @brief Write caller for kernels with several outputs (read_lens variant).
 *
 * `src` holds one ConditionalT<OutT, NumOuts> per element; the values are
 * first regrouped per output into `dst[output][element]`, then every output
 * row is handed to kps::WriteData.
 *
 * @param outs         output base pointers, one per output.
 * @param src          per-element results, indexable by output.
 * @param block_offset offset added to each output pointer before writing.
 * @param num          element count forwarded to kps::WriteData.
 * @param read_lens    number of elements regrouped per output; also forwarded
 *                     to kps::WriteData.
 */
template <typename OutT, int VecSize, bool IsBoundary, int NumOuts>
struct ElementwiseWriteDataCallerBc {
  __device__ __forceinline__ void operator()(
      phi::Array<_ptr_ OutT *, NumOuts> outs,
      ConditionalT<OutT, NumOuts> src[VecSize],
      int block_offset,
      int num,
      int read_lens) {
    OutT dst[NumOuts][VecSize];

    // Phase 1: transpose element-major results into output-major rows.
#pragma unroll
    for (int elem = 0; elem < read_lens; ++elem) {
#pragma unroll
      for (int out_idx = 0; out_idx < NumOuts; ++out_idx) {
        dst[out_idx][elem] = (src[elem])[out_idx];
      }
    }

    // Phase 2: store every output row.
#pragma unroll
    for (int out_idx = 0; out_idx < NumOuts; ++out_idx) {
      kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
          outs[out_idx] + block_offset, dst[out_idx], num, read_lens);
    }
  }
};
/**
 * @brief Single-output specialization of ElementwiseWriteDataCallerBc.
 *
 * With one output no regrouping is required: `src` is forwarded unchanged to
 * kps::WriteData together with `num` and `read_lens`.
 */
template <typename OutT, int VecSize, bool IsBoundary>
struct ElementwiseWriteDataCallerBc<OutT, VecSize, IsBoundary, 1> {
  __device__ __forceinline__ void operator()(phi::Array<_ptr_ OutT *, 1> outs,
                                             OutT src[VecSize],
                                             int block_offset,
                                             int num,
                                             int read_lens) {
    // Named stand-ins for the NY / BlockSize template arguments of WriteData.
    constexpr int kNY = 1;
    constexpr int kBlockSize = 1;
    kps::WriteData<OutT, VecSize, kNY, kBlockSize, IsBoundary>(
        outs[0] + block_offset, src, num, read_lens);
  }
};
template
<
typename
OutT
,
template
<
typename
OutT
,
typename
Functor
,
typename
Functor
,
int
Arity
,
int
Arity
,
...
...
paddle/phi/kernels/primitive/compute_primitives.h
浏览文件 @
c7855125
...
@@ -271,6 +271,20 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out,
...
@@ -271,6 +271,20 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out,
}
}
}
}
/**
 * @brief Binary calculation according to OpFunc, read_lens overload.
 *
 * Computes out[i] = static_cast<OutT>(compute(in1[i], in2[i])) for the
 * NX * NY elements fixed at compile time. `read_lens` is accepted so the
 * signature matches the other read_lens overloads, but this implementation
 * does not use it.
 *
 * @param out       destination for NX * NY results.
 * @param in1       first operand array.
 * @param in2       second operand array.
 * @param compute   callable invoked as compute(in1[i], in2[i]).
 * @param read_lens unused here.
 */
template <typename InT,
          typename OutT,
          int NX,
          int NY,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseBinary(OutT* out,
                                                  const InT* in1,
                                                  const InT* in2,
                                                  OpFunc compute,
                                                  int read_lens) {
  constexpr int kNumElements = NX * NY;
#pragma unroll
  for (int i = 0; i < kNumElements; ++i) {
    out[i] = static_cast<OutT>(compute(in1[i], in2[i]));
  }
}
/**
/**
* @brief Ternary calculation according to OpFunc. Shape of input and output
* @brief Ternary calculation according to OpFunc. Shape of input and output
* are the same.
* are the same.
...
...
paddle/phi/kernels/primitive/compute_primitives_xpu2.h
浏览文件 @
c7855125
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include "xpu/kernel/cluster_header.h"
#include "xpu/kernel/cluster_header.h"
#include "xpu/kernel/debug.h"
#include "xpu/kernel/debug.h"
#include "xpu/kernel/math.h"
#include "xpu/kernel/math.h"
#include "xpu/kernel/simd_header.h"
namespace
phi
{
namespace
phi
{
namespace
kps
{
namespace
kps
{
...
@@ -158,6 +159,19 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out,
...
@@ -158,6 +159,19 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out,
}
}
}
}
/**
 * @brief Binary calculation according to OpFunc, runtime-length overload.
 *
 * Computes out[i] = static_cast<OutT>(compute(in1[i], in2[i])) for
 * i in [0, read_lens). The element count comes from `read_lens`, not from the
 * NX / NY template parameters, which this implementation does not use.
 *
 * @param out       destination for `read_lens` results.
 * @param in1       first operand array.
 * @param in2       second operand array.
 * @param compute   callable invoked as compute(in1[i], in2[i]).
 * @param read_lens number of elements to process.
 */
template <typename InT,
          typename OutT,
          int NX,
          int NY,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseBinary(OutT* out,
                                                  const InT* in1,
                                                  const InT* in2,
                                                  OpFunc compute,
                                                  int read_lens) {
  int pos = 0;
  while (pos < read_lens) {
    out[pos] = static_cast<OutT>(compute(in1[pos], in2[pos]));
    ++pos;
  }
}
/**
/**
* @brief Ternary calculation according to OpFunc. Shape of input and output
* @brief Ternary calculation according to OpFunc. Shape of input and output
* are the same.
* are the same.
...
...
paddle/phi/kernels/primitive/datamover_primitives.h
浏览文件 @
c7855125
...
@@ -246,6 +246,14 @@ __device__ __forceinline__ void Init(T* dst, T init_data) {
...
@@ -246,6 +246,14 @@ __device__ __forceinline__ void Init(T* dst, T init_data) {
}
}
}
}
/**
 * @brief Fill the first NX elements of `dst` with `init_data`
 * (read_lens overload).
 *
 * `read_lens` is part of the signature but is not used by this
 * implementation; the element count is the compile-time NX.
 *
 * @param dst       buffer with room for at least NX elements.
 * @param init_data value copied into every element.
 * @param read_lens unused here.
 */
template <typename T, int NX>
__device__ __forceinline__ void Init(T* dst, T init_data, int read_lens) {
#pragma unroll
  for (int lane = 0; lane < NX; ++lane) {
    dst[lane] = init_data;
  }
}
/**
/**
* The difference from the above function is that
* The difference from the above function is that
* it supports different data types of inputs.
* it supports different data types of inputs.
...
@@ -311,6 +319,38 @@ __device__ __forceinline__ void ReadData(T* dst,
...
@@ -311,6 +319,38 @@ __device__ __forceinline__ void ReadData(T* dst,
}
}
}
}
/**
 * @brief Read this thread's NX elements from `src` into `dst`
 * (read_lens overload).
 *
 * When IsBoundary is true every element is bounds-checked against `num` and
 * copied one by one; elements at or past `num` leave `dst` untouched.
 * Otherwise the data is fetched through details::VectorType packs
 * (presumably an aligned group of kVectorSize values of T; confirm against
 * its definition) and then unpacked into `dst`.
 *
 * NY, BlockSize and `read_lens` are not used by this implementation.
 *
 * @param dst       destination buffer with room for NX elements.
 * @param src       source pointer for the current block.
 * @param num       number of valid elements, only checked when IsBoundary.
 * @param read_lens unused here.
 */
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void ReadData(T* dst,
                                         const T* __restrict__ src,
                                         int num,
                                         int read_lens) {
  if (IsBoundary) {  // blockDim.x * NX > num
    int thread_offset = threadIdx.x * NX;
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (idx + thread_offset < num) {
        dst[idx] = src[thread_offset + idx];
      }
    }
  } else {  // blockDim.x * NX < num
    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
    constexpr int kVectorsPerThread = NX / kVectorSize;
    int thread_offset = threadIdx.x * kVectorsPerThread;

    using VecType = details::VectorType<T, kVectorSize>;
    const VecType* vec_input = reinterpret_cast<const VecType*>(src);
    VecType vec_temp[kVectorsPerThread];

    // Load every vector first ...
#pragma unroll
    for (int i = 0; i < kVectorsPerThread; ++i) {
      vec_temp[i] = vec_input[thread_offset + i];
    }
    // ... then unpack once. Unpacking inside the load loop would read
    // vec_temp entries that have not been loaded yet and would repeat the
    // NX stores on every iteration.
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      dst[idx] = *(reinterpret_cast<T*>(vec_temp) + idx);
    }
  }
}
/**
/**
* @brief Read 1D data from global memory to register. The difference
* @brief Read 1D data from global memory to register. The difference
* from the above function is that it supports different data types of inputs.
* from the above function is that it supports different data types of inputs.
...
@@ -576,6 +616,36 @@ __device__ __forceinline__ void WriteData(T* dst,
...
@@ -576,6 +616,36 @@ __device__ __forceinline__ void WriteData(T* dst,
}
}
}
}
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
bool
IsBoundary
=
false
>
__device__
__forceinline__
void
WriteData
(
T
*
dst
,
T
*
__restrict__
src
,
int
num
,
int
read_lens
)
{
if
(
IsBoundary
)
{
int
thread_offset
=
threadIdx
.
x
*
NX
;
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
;
++
idx
)
{
if
((
thread_offset
+
idx
)
<
num
)
{
dst
[
thread_offset
+
idx
]
=
src
[
idx
];
}
}
}
else
{
// Vector type
constexpr
int
kVectorSize
=
(
NX
%
4
==
0
)
?
4
:
(
NX
%
2
==
0
)
?
2
:
1
;
constexpr
int
kVectorsPerThread
=
NX
/
kVectorSize
;
int
thread_offset
=
threadIdx
.
x
*
kVectorsPerThread
;
using
VecType
=
details
::
VectorType
<
T
,
kVectorSize
>
;
VecType
*
vec_dst
=
reinterpret_cast
<
VecType
*>
(
dst
);
VecType
vec_temp
[
kVectorsPerThread
];
#pragma unroll
for
(
int
idx
=
0
;
idx
<
kVectorsPerThread
;
++
idx
)
{
vec_temp
[
idx
]
=
*
(
reinterpret_cast
<
VecType
*>
(
src
)
+
idx
);
vec_dst
[
thread_offset
+
idx
]
=
vec_temp
[
idx
];
}
}
}
/**
/**
* @brief Write 2D data from register to global memory according to Tx type, and
* @brief Write 2D data from register to global memory according to Tx type, and
* store it as Ty type.
* store it as Ty type.
...
@@ -749,6 +819,40 @@ __device__ __forceinline__ void ReadDataBc(
...
@@ -749,6 +819,40 @@ __device__ __forceinline__ void ReadDataBc(
}
}
}
}
/**
 * @brief Read NX elements for this thread from a broadcast input
 * (read_lens overload).
 *
 * For each output index handled by the thread, the source index is built by
 * running the index through config.divmoders[i].Divmod for every dimension:
 * val[0] is carried to the next dimension and val[1] is scaled by
 * config.strides[i] and accumulated (presumably quotient and remainder;
 * confirm against the divmoder definition).
 *
 * When IsBoundary is true, the loop stops at the first output index that is
 * >= total_num_output. NY, BlockSize and `read_lens` are not used here.
 *
 * @param dst              destination buffer with room for NX elements.
 * @param src              base pointer of the broadcast input.
 * @param block_offset     output offset of the current block.
 * @param config           per-dimension divmoders and strides.
 * @param total_num_output total number of output elements.
 * @param read_lens        unused here.
 */
template <typename T,
          int NX,
          int NY,
          int BlockSize,
          int Rank,
          bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
    T* dst,
    const T* __restrict__ src,
    uint32_t block_offset,
    details::BroadcastConfig<Rank> config,
    int total_num_output,
    int read_lens) {
  const uint32_t first_output = block_offset + threadIdx.x * NX;

#pragma unroll
  for (uint32_t nx = 0; nx < NX; ++nx) {
    uint32_t remaining = first_output + nx;
    if (IsBoundary && remaining >= total_num_output) {
      break;
    }

    uint32_t src_index = 0;
#pragma unroll
    for (int dim = 0; dim < Rank; ++dim) {
      auto div_mod = config.divmoders[dim].Divmod(remaining);
      remaining = div_mod.val[0];
      src_index += div_mod.val[1] * config.strides[dim];
    }
    dst[nx] = src[src_index];
  }
}
/**
/**
* @brief Initialize register with data index.
* @brief Initialize register with data index.
*
*
...
...
paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
浏览文件 @
c7855125
此差异已折叠。
点击以展开。
paddle/phi/kernels/primitive/kernel_primitives.h
浏览文件 @
c7855125
...
@@ -46,6 +46,7 @@
...
@@ -46,6 +46,7 @@
#define KPStream gpuStream_t
#define KPStream gpuStream_t
#define KPDevice phi::GPUContext
#define KPDevice phi::GPUContext
#define _ptr_
#define _ptr_
#define __simd__
#define THREAD_ID_X threadIdx.x
#define THREAD_ID_X threadIdx.x
#define THREAD_ID_Y threadIdx.y
#define THREAD_ID_Y threadIdx.y
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录