Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
c7855125
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c7855125
编写于
5月 10, 2022
作者:
S
shixingbo
提交者:
GitHub
5月 10, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
broadcast_add kp performance optimization (#42097)
上级
81078a88
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
880 addition
and
43 deletion
+880
-43
paddle/phi/kernels/funcs/broadcast_function.h
paddle/phi/kernels/funcs/broadcast_function.h
+107
-31
paddle/phi/kernels/funcs/elementwise_base.h
paddle/phi/kernels/funcs/elementwise_base.h
+49
-7
paddle/phi/kernels/primitive/compute_primitives.h
paddle/phi/kernels/primitive/compute_primitives.h
+14
-0
paddle/phi/kernels/primitive/compute_primitives_xpu2.h
paddle/phi/kernels/primitive/compute_primitives_xpu2.h
+14
-0
paddle/phi/kernels/primitive/datamover_primitives.h
paddle/phi/kernels/primitive/datamover_primitives.h
+104
-0
paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
+591
-5
paddle/phi/kernels/primitive/kernel_primitives.h
paddle/phi/kernels/primitive/kernel_primitives.h
+1
-0
未找到文件。
paddle/phi/kernels/funcs/broadcast_function.h
浏览文件 @
c7855125
...
...
@@ -242,6 +242,27 @@ __device__ __forceinline__ void LoadData(
}
}
// Fills `dst` with this thread's inputs for one tile, dispatching to either
// the broadcast-aware reader or the plain contiguous reader.
//   numel          : whole number of output elements (forwarded to ReadDataBc)
//   num            : how many elements are dealt with in this pass (forwarded
//                    to ReadData)
//   need_broadcast : non-zero selects kps::ReadDataBc; zero selects
//                    kps::ReadData on `src + block_offset`
//   read_lens      : forwarded unchanged to the selected reader
template <typename T, int VecSize, int Rank, bool IsBoundary = false>
__device__ __forceinline__ void LoadData(
    T* dst,
    const _ptr_ T* src,
    uint32_t block_offset,
    const kps::details::BroadcastConfig<Rank>& config,
    int numel,
    int num,
    int need_broadcast,
    int read_lens) {
  if (!need_broadcast) {
    // No broadcast: read directly from the block's starting position.
    kps::ReadData<T, VecSize, 1, 1, IsBoundary>(
        dst, src + block_offset, num, read_lens);
    return;
  }
  // Broadcast: the reader maps output positions to source positions itself.
  kps::ReadDataBc<T, VecSize, 1, 1, Rank, IsBoundary>(
      dst, src, block_offset, config, numel, read_lens);
}
template
<
typename
InT
,
typename
OutT
,
typename
Functor
,
...
...
@@ -258,20 +279,22 @@ __device__ void VectorizedBroadcastKernelImpl(
const
phi
::
Array
<
kps
::
details
::
BroadcastConfig
<
Rank
>
,
Arity
>
&
configs
,
int
num
,
int
block_offset
,
int
read_lens
,
Functor
func
)
{
InT
args
[
Arity
][
VecSize
];
ConditionalT
<
OutT
,
NumOuts
>
result
[
VecSize
];
__simd__
InT
args
[
Arity
][
VecSize
];
__simd__
ConditionalT
<
OutT
,
NumOuts
>
result
[
VecSize
];
#pragma unroll
for
(
int
i
=
0
;
i
<
Arity
;
i
++
)
{
kps
::
Init
<
InT
,
VecSize
>
(
args
[
i
],
static_cast
<
InT
>
(
1.0
f
));
kps
::
Init
<
InT
,
VecSize
>
(
args
[
i
],
static_cast
<
InT
>
(
1.0
f
)
,
read_lens
);
LoadData
<
InT
,
VecSize
,
Rank
,
IsBoundary
>
(
args
[
i
],
ins
[
i
],
block_offset
,
configs
[
i
],
numel
,
num
,
use_broadcast
[
i
]);
use_broadcast
[
i
],
read_lens
);
}
constexpr
bool
kCallElementwiseAny
=
paddle
::
platform
::
FunctionTraits
<
Functor
>::
has_pointer_args
;
...
...
@@ -281,10 +304,10 @@ __device__ void VectorizedBroadcastKernelImpl(
Functor
,
Arity
,
kCallElementwiseAny
>
()(
func
,
args
,
result
);
phi
::
funcs
::
ElementwiseWriteDataCaller
<
OutT
,
VecSize
,
IsBoundary
,
NumOuts
>
()(
outs
,
result
,
block_offset
,
num
);
func
,
args
,
result
,
read_lens
);
phi
::
funcs
::
ElementwiseWriteDataCallerBc
<
OutT
,
VecSize
,
IsBoundary
,
NumOuts
>
()(
outs
,
result
,
block_offset
,
num
,
read_lens
);
}
template
<
typename
InT
,
...
...
@@ -302,9 +325,10 @@ __global__ void VectorizedBroadcastKernel(
phi
::
Array
<
kps
::
details
::
BroadcastConfig
<
Rank
>
,
Arity
>
configs
,
int
main_offset
,
int
tail_tid
,
int
read_lens
,
Functor
func
)
{
int
block_offset
=
BLOCK_ID_X
*
BLOCK_NUM_X
*
VecSize
;
int
stride
=
BLOCK_NUM_X
*
GRID_NUM_X
*
VecSize
;
int
block_offset
=
BLOCK_ID_X
*
BLOCK_NUM_X
*
read_lens
;
int
stride
=
BLOCK_NUM_X
*
GRID_NUM_X
*
read_lens
;
#ifdef PADDLE_WITH_XPU_KP
for
(;
block_offset
<
main_offset
;
block_offset
+=
stride
)
{
...
...
@@ -320,8 +344,9 @@ __global__ void VectorizedBroadcastKernel(
use_broadcast
,
numel
,
configs
,
BLOCK_NUM_X
*
VecSize
,
BLOCK_NUM_X
*
read_lens
,
block_offset
,
read_lens
,
func
);
}
int
num
=
numel
-
block_offset
;
...
...
@@ -333,8 +358,15 @@ __global__ void VectorizedBroadcastKernel(
NumOuts
,
VecSize
,
Rank
,
true
>
(
ins
,
outs
,
use_broadcast
,
numel
,
configs
,
num
,
block_offset
,
func
);
true
>
(
ins
,
outs
,
use_broadcast
,
numel
,
configs
,
num
,
block_offset
,
read_lens
,
func
);
}
#else
if
(
block_offset
<
main_offset
)
{
...
...
@@ -352,6 +384,7 @@ __global__ void VectorizedBroadcastKernel(
configs
,
BLOCK_NUM_X
*
VecSize
,
block_offset
,
read_lens
,
func
);
}
else
{
VectorizedBroadcastKernelImpl
<
InT
,
...
...
@@ -361,8 +394,15 @@ __global__ void VectorizedBroadcastKernel(
NumOuts
,
VecSize
,
Rank
,
true
>
(
ins
,
outs
,
use_broadcast
,
numel
,
configs
,
tail_tid
,
block_offset
,
func
);
true
>
(
ins
,
outs
,
use_broadcast
,
numel
,
configs
,
tail_tid
,
block_offset
,
read_lens
,
func
);
}
#endif
}
...
...
@@ -392,6 +432,19 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
for
(
int
i
=
0
;
i
<
Arity
;
i
++
)
{
use_broadcast
[
i
]
=
(
ins
[
i
]
->
numel
()
!=
numel
);
ins_data
[
i
]
=
(
const
_ptr_
InT
*
)(
ins
[
i
]
->
data
<
InT
>
());
#ifdef PADDLE_WITH_XPU_KP
if
(
i
==
0
)
{
configs
[
i
]
=
kps
::
details
::
BroadcastConfig
<
Rank
>
(
merge_dims
.
out_dims
,
merge_dims
.
in_dims
[
0
],
merge_dims
.
in_dims
[
1
],
merge_dims
.
dim_size
);
}
else
if
(
i
==
1
)
{
configs
[
i
]
=
kps
::
details
::
BroadcastConfig
<
Rank
>
(
merge_dims
.
out_dims
,
merge_dims
.
in_dims
[
1
],
merge_dims
.
in_dims
[
0
],
merge_dims
.
dim_size
);
}
#else
if
(
use_broadcast
[
i
])
{
// get the broadcast config,
// if data shape is[m, n], then you should set data_dim = {n, m}
...
...
@@ -399,28 +452,50 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
configs
[
i
]
=
kps
::
details
::
BroadcastConfig
<
Rank
>
(
merge_dims
.
out_dims
,
merge_dims
.
in_dims
[
i
],
merge_dims
.
dim_size
);
}
#endif
}
#ifdef PADDLE_WITH_XPU_KP
const
int
threads
=
64
;
const
int
blocks
=
8
;
int
main_offset
=
(
numel
/
(
VecSize
*
threads
))
*
VecSize
*
threads
;
int
tail_tid
=
numel
%
(
VecSize
*
threads
);
int
read_lens
=
configs
[
0
].
buf_len
;
int
main_offset
=
(
numel
/
(
read_lens
*
threads
))
*
read_lens
*
threads
;
int
tail_tid
=
numel
%
(
read_lens
*
threads
);
auto
stream
=
ctx
.
x_context
()
->
xpu_stream
;
VectorizedBroadcastKernel
<
InT
,
OutT
,
Functor
,
Arity
,
NumOuts
,
VecSize
,
Rank
><<<
blocks
,
threads
,
stream
>>>
(
ins_data
,
outs_data
,
use_broadcast
,
numel
,
configs
,
main_offset
,
tail_tid
,
func
);
if
(
configs
[
0
].
cmp_type
!=
kps
::
details
::
OptType
::
CanNotOptimize
)
{
main_offset
=
numel
;
VectorizedBroadcastKernel
<
InT
,
OutT
,
Functor
,
Arity
,
NumOuts
,
512
,
Rank
><<<
blocks
,
threads
,
stream
>>>
(
ins_data
,
outs_data
,
use_broadcast
,
numel
,
configs
,
main_offset
,
tail_tid
,
read_lens
,
func
);
}
else
{
VectorizedBroadcastKernel
<
InT
,
OutT
,
Functor
,
Arity
,
NumOuts
,
256
,
Rank
><<<
blocks
,
threads
,
stream
>>>
(
ins_data
,
outs_data
,
use_broadcast
,
numel
,
configs
,
main_offset
,
tail_tid
,
read_lens
,
func
);
}
#else
const
int
threads
=
256
;
int
blocks
=
((
numel
+
VecSize
-
1
)
/
VecSize
+
threads
-
1
)
/
threads
;
...
...
@@ -440,6 +515,7 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
configs
,
main_offset
,
tail_tid
,
VecSize
,
func
);
#endif
}
...
...
paddle/phi/kernels/funcs/elementwise_base.h
浏览文件 @
c7855125
...
...
@@ -577,14 +577,16 @@ template <typename InT,
struct
ElementwisePrimitiveCaller
{
__device__
inline
void
operator
()(
Functor
func
,
InT
(
*
args
)[
VecSize
],
OutT
*
result
);
OutT
*
result
,
int
read_lens
);
};
template
<
typename
InT
,
typename
OutT
,
int
VecSize
,
typename
Functor
,
int
Arity
>
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
Arity
,
true
>
{
__device__
inline
void
operator
()(
Functor
func
,
InT
(
*
args
)[
VecSize
],
OutT
*
result
)
{
OutT
*
result
,
int
read_lens
)
{
kps
::
ElementwiseAny
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Arity
,
Functor
>
(
result
,
args
,
func
);
}
...
...
@@ -594,7 +596,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
0
,
false
>
{
__device__
inline
void
operator
()(
Functor
func
,
InT
(
*
args
)[
VecSize
],
OutT
*
result
)
{
OutT
*
result
,
int
read_lens
)
{
kps
::
ElementwiseConstant
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Functor
>
(
result
,
func
);
}
};
...
...
@@ -603,7 +606,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
1
,
false
>
{
__device__
inline
void
operator
()(
Functor
func
,
InT
(
*
args
)[
VecSize
],
OutT
*
result
)
{
OutT
*
result
,
int
read_lens
)
{
kps
::
ElementwiseUnary
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Functor
>
(
result
,
args
[
0
],
func
);
}
...
...
@@ -613,9 +617,10 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
2
,
false
>
{
__device__
inline
void
operator
()(
Functor
func
,
InT
(
*
args
)[
VecSize
],
OutT
*
result
)
{
OutT
*
result
,
int
read_lens
)
{
kps
::
ElementwiseBinary
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Functor
>
(
result
,
args
[
0
],
args
[
1
],
func
);
result
,
args
[
0
],
args
[
1
],
func
,
read_lens
);
}
};
...
...
@@ -623,7 +628,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
3
,
false
>
{
__device__
inline
void
operator
()(
Functor
func
,
InT
(
*
args
)[
VecSize
],
OutT
*
result
)
{
OutT
*
result
,
int
read_lens
)
{
kps
::
ElementwiseTernary
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Functor
>
(
result
,
args
[
0
],
args
[
1
],
args
[
2
],
func
);
}
...
...
@@ -696,6 +702,42 @@ struct ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, 1> {
}
};
// Writes `read_lens` results per output tensor for functors with several
// outputs. `src[e]` holds all NumOuts results of element `e`; they are first
// regrouped into one contiguous row per output and each row is then handed to
// kps::WriteData together with `num` and `read_lens`.
template <typename OutT, int VecSize, bool IsBoundary, int NumOuts>
struct ElementwiseWriteDataCallerBc {
  __device__ __forceinline__ void operator()(
      phi::Array<_ptr_ OutT*, NumOuts> outs,
      ConditionalT<OutT, NumOuts> src[VecSize],
      int block_offset,
      int num,
      int read_lens) {
    // per_output[o][e] is the e-th element destined for output tensor o.
    OutT per_output[NumOuts][VecSize];
#pragma unroll
    for (int elem = 0; elem < read_lens; ++elem) {
#pragma unroll
      for (int out_id = 0; out_id < NumOuts; ++out_id) {
        per_output[out_id][elem] = (src[elem])[out_id];
      }
    }
#pragma unroll
    for (int out_id = 0; out_id < NumOuts; ++out_id) {
      kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
          outs[out_id] + block_offset, per_output[out_id], num, read_lens);
    }
  }
};
// Single-output specialization: `src` is already a plain OutT array, so it is
// passed straight to kps::WriteData without any regrouping.
template <typename OutT, int VecSize, bool IsBoundary>
struct ElementwiseWriteDataCallerBc<OutT, VecSize, IsBoundary, 1> {
  __device__ __forceinline__ void operator()(phi::Array<_ptr_ OutT*, 1> outs,
                                             OutT src[VecSize],
                                             int block_offset,
                                             int num,
                                             int read_lens) {
    // Start of this block's region in the only output tensor.
    _ptr_ OutT* block_out = outs[0] + block_offset;
    kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
        block_out, src, num, read_lens);
  }
};
template
<
typename
OutT
,
typename
Functor
,
int
Arity
,
...
...
paddle/phi/kernels/primitive/compute_primitives.h
浏览文件 @
c7855125
...
...
@@ -271,6 +271,20 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out,
}
}
// Overload of ElementwiseBinary that also accepts `read_lens`.
// Applies `compute` element by element to `in1` and `in2` and stores the
// result, cast to OutT, in `out`. This implementation always processes
// NX * NY elements; `read_lens` is accepted but not consulted here.
template <typename InT,
          typename OutT,
          int NX,
          int NY,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseBinary(
    OutT* out, const InT* in1, const InT* in2, OpFunc compute, int read_lens) {
  constexpr int kElemCount = NX * NY;
#pragma unroll
  for (int pos = 0; pos < kElemCount; ++pos) {
    out[pos] = static_cast<OutT>(compute(in1[pos], in2[pos]));
  }
}
/**
* @brief Ternary calculation according to OpFunc. Shape of input and output
* are the same.
...
...
paddle/phi/kernels/primitive/compute_primitives_xpu2.h
浏览文件 @
c7855125
...
...
@@ -17,6 +17,7 @@
#include "xpu/kernel/cluster_header.h"
#include "xpu/kernel/debug.h"
#include "xpu/kernel/math.h"
#include "xpu/kernel/simd_header.h"
namespace
phi
{
namespace
kps
{
...
...
@@ -158,6 +159,19 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out,
}
}
// Overload of ElementwiseBinary whose element count is given at run time.
// Applies `compute` to the first `read_lens` elements of `in1` and `in2` and
// stores each result, cast to OutT, in `out`. NX, NY and BlockSize are not
// used by the body.
template <typename InT,
          typename OutT,
          int NX,
          int NY,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseBinary(
    OutT* out, const InT* in1, const InT* in2, OpFunc compute, int read_lens) {
  for (int pos = 0; pos < read_lens; ++pos) {
    out[pos] = static_cast<OutT>(compute(in1[pos], in2[pos]));
  }
}
/**
* @brief Ternary calculation according to OpFunc. Shape of input and output
* are the same.
...
...
paddle/phi/kernels/primitive/datamover_primitives.h
浏览文件 @
c7855125
...
...
@@ -246,6 +246,14 @@ __device__ __forceinline__ void Init(T* dst, T init_data) {
}
}
// Overload of Init that also accepts `read_lens`.
// Sets the first NX elements of `dst` to `init_data`. This implementation
// does not consult `read_lens`; the element count is the compile-time NX.
template <typename T, int NX>
__device__ __forceinline__ void Init(T* dst, T init_data, int read_lens) {
#pragma unroll
  for (int slot = 0; slot < NX; slot++) {
    dst[slot] = init_data;
  }
}
/**
* The difference from the above function is that
* it supports different data types of inputs.
...
...
@@ -311,6 +319,38 @@ __device__ __forceinline__ void ReadData(T* dst,
}
}
// Overload of ReadData that also accepts `read_lens` (not consulted by this
// implementation; the per-thread element count is the compile-time NX).
//
// Copies NX elements for the calling thread from `src` into `dst`.
//   IsBoundary == true : element-wise copy starting at threadIdx.x * NX,
//                        skipping positions that are >= `num`.
//   IsBoundary == false: copy through details::VectorType loads of width
//                        4, 2 or 1 (the largest of those that divides NX).
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void ReadData(T* dst,
                                         const T* __restrict__ src,
                                         int num,
                                         int read_lens) {
  if (IsBoundary) {  // blockDim.x * NX > num
    int thread_offset = threadIdx.x * NX;
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      if (idx + thread_offset < num) {
        dst[idx] = src[thread_offset + idx];
      }
    }
  } else {  // blockDim.x * NX < num
    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
    constexpr int kVectorsPerThread = NX / kVectorSize;
    int thread_offset = threadIdx.x * kVectorsPerThread;

    using VecType = details::VectorType<T, kVectorSize>;
    const VecType* vec_input = reinterpret_cast<const VecType*>(src);
    VecType vec_temp[kVectorsPerThread];

    // Load every vector first ...
#pragma unroll
    for (int i = 0; i < kVectorsPerThread; ++i) {
      vec_temp[i] = vec_input[thread_offset + i];
    }
    // ... then unpack once. Unpacking inside the load loop would read
    // vec_temp slots that have not been loaded yet and repeat the copy
    // kVectorsPerThread times; the final contents of dst are the same.
#pragma unroll
    for (int idx = 0; idx < NX; ++idx) {
      dst[idx] = *(reinterpret_cast<T*>(vec_temp) + idx);
    }
  }
}
/**
* @brief Read 1D data from global memory to register. The difference
* from the above function is that it supports different data types of inputs.
...
...
@@ -576,6 +616,36 @@ __device__ __forceinline__ void WriteData(T* dst,
}
}
// Overload of WriteData that also accepts `read_lens` (not consulted by this
// implementation; the per-thread element count is the compile-time NX).
//
// Stores the calling thread's NX elements from `src` into `dst`.
//   IsBoundary == true : element-wise store at threadIdx.x * NX + i, skipping
//                        positions that are >= `num`.
//   IsBoundary == false: store through details::VectorType of width 4, 2 or 1
//                        (the largest of those that divides NX).
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void WriteData(T* dst,
                                          T* __restrict__ src,
                                          int num,
                                          int read_lens) {
  if (!IsBoundary) {
    // Vector type
    constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
    constexpr int kVectorsPerThread = NX / kVectorSize;
    using VecType = details::VectorType<T, kVectorSize>;

    const int first_vec = threadIdx.x * kVectorsPerThread;
    VecType* vec_out = reinterpret_cast<VecType*>(dst);
    VecType* vec_in = reinterpret_cast<VecType*>(src);
#pragma unroll
    for (int v = 0; v < kVectorsPerThread; ++v) {
      VecType staged = *(vec_in + v);
      vec_out[first_vec + v] = staged;
    }
    return;
  }

  const int first_elem = threadIdx.x * NX;
#pragma unroll
  for (int e = 0; e < NX; ++e) {
    if ((first_elem + e) < num) {
      dst[first_elem + e] = src[e];
    }
  }
}
/**
* @brief Write 2D data from register to global memory according to Tx type, and
* store it as Ty type.
...
...
@@ -749,6 +819,40 @@ __device__ __forceinline__ void ReadDataBc(
}
}
// Overload of ReadDataBc that also accepts `read_lens` (not consulted by this
// implementation; the per-thread element count is the compile-time NX).
//
// For each of the thread's NX output positions, starting at
// block_offset + threadIdx.x * NX, the position is decomposed with
// config.divmoders[i].Divmod over the Rank dimensions; the remainders weighted
// by config.strides[i] give the source index, and src[that index] is stored in
// dst[nx]. With IsBoundary set, the loop stops at the first output position
// that is >= total_num_output.
template <typename T,
          int NX,
          int NY,
          int BlockSize,
          int Rank,
          bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc(
    T* dst,
    const T* __restrict__ src,
    uint32_t block_offset,
    details::BroadcastConfig<Rank> config,
    int total_num_output,
    int read_lens) {
  const uint32_t first_output = block_offset + threadIdx.x * NX;

#pragma unroll
  for (uint32_t nx = 0; nx < NX; ++nx) {
    uint32_t remaining = first_output + nx;
    if (IsBoundary && remaining >= total_num_output) {
      break;
    }

    uint32_t src_index = 0;
#pragma unroll
    for (int dim = 0; dim < Rank; ++dim) {
      auto quot_rem = config.divmoders[dim].Divmod(remaining);
      remaining = quot_rem.val[0];
      src_index += quot_rem.val[1] * config.strides[dim];
    }
    dst[nx] = src[src_index];
  }
}
/**
* @brief Initialize register with data index.
*
...
...
paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
浏览文件 @
c7855125
...
...
@@ -21,6 +21,39 @@ namespace phi {
namespace
kps
{
namespace
details
{
// Optimize type of calc after input shape compressed.
// The numeric values are written out explicitly; they are the same values the
// enumerators receive from implicit numbering after CanNotOptimize = -1.
enum class OptType {
  CanNotOptimize = -1,  // can not optimize, broadcast first
  N_1 = 0,              // just like {1} op {100} or {100} op {1}
  MN_N = 1,             // just like {100} op {3, 100} or {3, 100} op {100}
  MN_M = 2,             // just like {3} op {3, 100} or {3, 100} op {3}
  MNK_1N1 = 3,          // just like {3} op {2, 3, 100} or {2, 3, 100} op {3}
  MNK_M1K = 4,          // just like {2, 1, 100} op {2, 3, 100} or
                        // {2, 3, 100} op {2, 1, 100}
};
// Rules to determine whether dimensions can be merged.
// Classifies one pair of dimension sizes (a from x, b from y):
//   rule 0  - a == b
//   rule 1  - a == 1 && b != 1
//   rule 2  - a != 1 && b == 1
//   -1      - none of the above (a and b differ and neither is 1)
static int judge_case(int a, int b) {
  if (a == b) {
    return 0;
  }
  // From here on a != b, so at most one of them can be 1.
  if (a == 1) {
    return 1;
  }
  if (b == 1) {
    return 2;
  }
  return -1;
}
// Returns true when two rule codes (as produced by judge_case) are equal.
static bool case_is_same(int case_front, int case_back) {
  return case_front == case_back;
}
template
<
typename
T
,
int
VecSize
>
struct
alignas
(
sizeof
(
T
)
*
VecSize
)
VectorType
{
T
val
[
VecSize
];
...
...
@@ -37,11 +70,20 @@ struct BroadcastConfig {
int
strides_in
[
phi
::
DDim
::
kMaxRank
];
int
strides_out
[
phi
::
DDim
::
kMaxRank
];
int
in_dim
[
phi
::
DDim
::
kMaxRank
];
int
dim_after_cmp
[
phi
::
DDim
::
kMaxRank
];
int
dim_size_after_cmp
=
0
;
int
cmp_res
=
0
;
OptType
cmp_type
=
OptType
::
CanNotOptimize
;
int
m
=
1
;
int
n
=
1
;
int
k
=
1
;
int
buf_len
=
0
;
HOSTDEVICE
BroadcastConfig
()
{}
HOSTDEVICE
BroadcastConfig
(
const
std
::
vector
<
int64_t
>&
out_dims
,
const
std
::
vector
<
int64_t
>&
in_dims
,
const
std
::
vector
<
int64_t
>&
another_in_dims
,
int
dim_size
)
{
std
::
vector
<
int
>
strides_in_tmp
;
std
::
vector
<
int
>
strides_out_tmp
;
...
...
@@ -61,18 +103,187 @@ struct BroadcastConfig {
memcpy
(
strides_in
,
strides_in_tmp
.
data
(),
kDims
*
sizeof
(
int
));
memcpy
(
strides_out
,
strides_out_tmp
.
data
(),
kDims
*
sizeof
(
int
));
memcpy
(
in_dim
,
dim_tmp
.
data
(),
kDims
*
sizeof
(
int
));
cmp_res
=
get_mnk_for_broadcast_ops
(
in_dims
,
another_in_dims
);
get_opt_type
(
another_in_dims
);
buf_len
=
get_buf_len
();
}
// Returns the buffer length derived from `cmp_type` and `m`:
//   * 256 when cmp_type is OptType::CanNotOptimize;
//   * otherwise m rounded down to a multiple of 16 (or m itself when that
//     rounding yields 0), capped at 512.
int get_buf_len() {
  if (cmp_type == OptType::CanNotOptimize) {
    return 256;
  }
  const int length_cap = 512;
  const int rounded_m = m / 16 * 16;
  const int candidate = (rounded_m == 0) ? m : rounded_m;
  return std::min(length_cap, candidate);
}
__device__
inline
int
operator
()(
int
index_output
)
const
{
int
index_src
=
0
;
#pragma unroll
for
(
int
i
=
kDims
-
1
;
i
>=
0
;
--
i
)
{
int
tmp_index
=
(
index_output
/
strides_out
[
i
]);
index_output
=
index_output
-
tmp_index
*
strides_out
[
i
];
index_src
+=
(
tmp_index
%
in_dim
[
i
])
*
strides_in
[
i
];
switch
(
cmp_type
)
{
int
div
,
mod
,
tmp_index
;
case
OptType
::
MNK_M1K
:
div
=
index_output
/
(
m
*
n
);
mod
=
index_output
%
(
m
*
n
)
%
m
;
index_src
=
div
*
m
+
mod
;
break
;
case
OptType
::
MNK_1N1
:
// index_src = index_output / m % n;
index_src
=
index_output
%
(
m
*
n
)
/
m
;
break
;
case
OptType
::
N_1
:
index_src
=
0
;
break
;
case
OptType
::
MN_N
:
index_src
=
index_output
/
m
;
break
;
case
OptType
::
MN_M
:
index_src
=
index_output
%
m
;
break
;
case
OptType
::
CanNotOptimize
:
for
(
int
i
=
kDims
-
1
;
i
>=
0
;
--
i
)
{
tmp_index
=
(
index_output
/
strides_out
[
i
]);
index_output
=
index_output
-
tmp_index
*
strides_out
[
i
];
index_src
+=
(
tmp_index
%
in_dim
[
i
])
*
strides_in
[
i
];
}
break
;
}
return
index_src
;
}
// Chooses `cmp_type` from the pattern of ones in this input's compressed dims
// (`dim_after_cmp`, `dim_size_after_cmp` entries) and in `y_dim_after_cmp`,
// and records m / n / k from whichever side has no broadcast dimension.
// Only dim_size_after_cmp values 1, 2 and 3 are handled; any other value
// leaves every member untouched.
void get_opt_type(const std::vector<int64_t>& y_dim_after_cmp) {
  const int* x = dim_after_cmp;
  const std::vector<int64_t>& y = y_dim_after_cmp;

  // Record sizes from one side and set the detected type.
  auto adopt_y = [&](OptType kind, bool with_k) {
    m = y[0];
    n = y[1];
    if (with_k) {
      k = y[2];
    }
    cmp_type = kind;
  };
  auto adopt_x = [&](OptType kind, bool with_k) {
    m = x[0];
    n = x[1];
    if (with_k) {
      k = x[2];
    }
    cmp_type = kind;
  };

  switch (dim_size_after_cmp) {
    case 1:
      if (x[0] == 1 && y[0] != 1) {
        // {1} op {n}
        n = y[0];
        cmp_type = OptType::N_1;
      } else if (x[0] != 1 && y[0] == 1) {
        // {n} op {1}
        n = x[0];
        cmp_type = OptType::N_1;
      } else {
        // xshape == yshape
        cmp_type = OptType::CanNotOptimize;
      }
      break;

    case 2:
      if (x[0] == 1 && x[1] != 1 && y[0] != 1 && y[1] != 1) {
        // {n} op {m, n}
        adopt_y(OptType::MN_N, false);
      } else if (x[0] != 1 && x[1] == 1 && y[0] != 1 && y[1] != 1) {
        // {m} op {m, n}
        adopt_y(OptType::MN_M, false);
      } else if (x[0] != 1 && x[1] != 1 && y[0] == 1 && y[1] != 1) {
        // {m, n} op {n}
        adopt_x(OptType::MN_N, false);
      } else if (x[0] != 1 && x[1] != 1 && y[0] != 1 && y[1] == 1) {
        // {m, n} op {m}
        adopt_x(OptType::MN_M, false);
      } else {
        cmp_type = OptType::CanNotOptimize;
      }
      break;

    case 3:
      if (x[0] == 1 && x[1] != 1 && x[2] == 1 && y[0] != 1 && y[1] != 1 &&
          y[2] != 1) {
        // {1, n, 1} op {m, n, k}
        adopt_y(OptType::MNK_1N1, true);
      } else if (x[0] != 1 && x[1] != 1 && x[2] != 1 && y[0] == 1 &&
                 y[1] != 1 && y[2] == 1) {
        // {m, n, k} op {1, n, 1}
        adopt_x(OptType::MNK_1N1, true);
      } else if (x[0] != 1 && x[1] == 1 && x[2] != 1 && y[0] != 1 &&
                 y[1] != 1 && y[2] != 1) {
        // {m, 1, k} op {m, n, k}
        adopt_y(OptType::MNK_M1K, true);
      } else if (x[0] != 1 && x[1] != 1 && x[2] != 1 && y[0] != 1 &&
                 y[1] == 1 && y[2] != 1) {
        // {m, n, k} op {m, 1, k}
        adopt_x(OptType::MNK_M1K, true);
      } else {
        cmp_type = OptType::CanNotOptimize;
      }
      break;

    default:
      break;
  }
}
int
get_mnk_for_broadcast_ops
(
const
std
::
vector
<
int64_t
>&
xshape
,
const
std
::
vector
<
int64_t
>&
yshape
)
{
int
idx
=
0
;
int
cmp_x
=
0
;
int
cmp_y
=
0
;
bool
is_same
=
false
;
std
::
vector
<
int64_t
>
xshape_after_remove_ones
=
xshape
;
std
::
vector
<
int64_t
>
yshape_after_remove_ones
=
yshape
;
// first step: remove excess ones
std
::
vector
<
int64_t
>::
iterator
x_iter
=
xshape_after_remove_ones
.
begin
();
std
::
vector
<
int64_t
>::
iterator
y_iter
=
yshape_after_remove_ones
.
begin
();
for
(;
x_iter
!=
xshape_after_remove_ones
.
end
();)
{
if
(
*
x_iter
==
1
&&
*
y_iter
==
1
)
{
x_iter
=
xshape_after_remove_ones
.
erase
(
x_iter
);
y_iter
=
yshape_after_remove_ones
.
erase
(
y_iter
);
}
else
{
x_iter
++
;
y_iter
++
;
}
}
// second step: compress dims
int
after_cmp_idx
=
0
;
for
(
int
i
=
0
;
i
<
3
;
i
++
)
{
cmp_x
=
xshape_after_remove_ones
[
idx
];
cmp_y
=
yshape_after_remove_ones
[
idx
];
while
((
idx
+
1
)
<
xshape_after_remove_ones
.
size
())
{
is_same
=
case_is_same
(
judge_case
(
xshape_after_remove_ones
[
idx
],
yshape_after_remove_ones
[
idx
]),
judge_case
(
xshape_after_remove_ones
[
idx
+
1
],
yshape_after_remove_ones
[
idx
+
1
]));
if
(
is_same
)
{
cmp_x
=
cmp_x
*
xshape_after_remove_ones
[
idx
+
1
];
cmp_y
=
cmp_y
*
yshape_after_remove_ones
[
idx
+
1
];
idx
++
;
}
else
{
break
;
}
}
idx
=
idx
+
1
;
dim_after_cmp
[
after_cmp_idx
]
=
cmp_x
;
after_cmp_idx
++
;
if
(
idx
==
xshape_after_remove_ones
.
size
())
{
dim_size_after_cmp
=
after_cmp_idx
;
return
0
;
}
}
return
-
1
;
// can not compress dims
}
};
#pragma pack()
...
...
@@ -199,6 +410,14 @@ __device__ __inline__ void Init(T* dst, T init_data) {
}
}
// Overload of Init whose element count is given at run time: sets the first
// `read_lens` elements of `dst` to `init_data`. NX is not used by the body.
template <typename T, int NX>
__device__ __inline__ void Init(T* dst, T init_data, int read_lens) {
#pragma unroll
  for (int slot = 0; slot < read_lens; slot++) {
    dst[slot] = init_data;
  }
}
/**
* The difference from the above function is that
* it supports different data types of inputs.
...
...
@@ -251,6 +470,26 @@ __device__ __inline__ void ReadData(T* dst,
}
}
// Overload of ReadData whose per-core element count is `read_lens`.
// The calling core's region starts at core_id() * read_lens within `src`.
//   IsBoundary == false: one GM2LM transfer of read_lens elements into dst.
//   IsBoundary == true : one GM2LM transfer per element, skipping positions
//                        that are >= `num`; skipped dst slots are not written.
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary>
__device__ __inline__ void ReadData(T* dst,
                                    const T _global_ptr_* src,
                                    int num,
                                    int read_lens) {
  int core_begin = core_id() * read_lens;
  __local__ T staged[1];

  if (!IsBoundary) {  // core_num() * read_lens < num
    GM2LM(src + core_begin, dst, read_lens * sizeof(T));
    return;
  }

  // core_num() * read_lens > num
#pragma unroll
  for (int e = 0; e < read_lens; ++e) {
    if (e + core_begin < num) {
      GM2LM(src + core_begin + e, staged, sizeof(T));
      dst[e] = staged[0];
    }
  }
}
/**
* @brief Read 1D data from global memory to register. The difference
* from the above function is that it supports different data types of inputs.
...
...
@@ -479,10 +718,32 @@ __device__ __forceinline__ void ReadDataReduce(
* size: The current block needs to load size elements continuously.
*/
// Overload of WriteData whose per-core element count is `read_lens`.
// The calling core's region starts at core_id() * read_lens within `dst`.
//   IsBoundary == false: one LM2GM transfer of read_lens elements from src.
//   IsBoundary == true : one LM2GM transfer per element, skipping positions
//                        that are >= `num`.
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary>
__device__ void WriteData(T _global_ptr_* dst,
                          const T* src,
                          int num,
                          int read_lens) {
  int core_begin = core_id() * read_lens;
  __local__ T staged[1];

  if (!IsBoundary) {  // core_num() * read_lens < num
    LM2GM(src, dst + core_begin, read_lens * sizeof(T));
    return;
  }

  // core_num() * read_lens > num
#pragma unroll
  for (int e = 0; e < read_lens; ++e) {
    if (e + core_begin < num) {
      staged[0] = src[e];
      LM2GM(staged, dst + e + core_begin, sizeof(T));
    }
  }
}
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
bool
IsBoundary
>
__device__
void
WriteData
(
T
_global_ptr_
*
dst
,
const
T
*
src
,
int
num
)
{
int
thread_offset
=
core_id
()
*
NX
;
__local__
T
in_temp
[
1
];
if
(
IsBoundary
)
{
// core_num() * NX > num
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
;
++
idx
)
{
...
...
@@ -675,6 +936,331 @@ __device__ __inline__ void ReadDataBc(
}
}
/**
 * @brief Read data from global memory to local memory with broadcast
 * {m, 1, k}-> {m, n, k} form.
 *
 * @template parameters
 * T: Data type of register.
 * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
 *
 * @param:
 * dst: Destination buffer of the thread; read_lens elements are written.
 * src: The original input data pointer of kernel.
 * thread_offset: The data offset of this thread.
 * config: Calculation configuration of broadcast. It is used to calculate the
 * coordinate mapping relationship between output data and input data.
 * read_lens: The number of data continuously loaded by each thread.
 */
template <typename T, int Rank>
__device__ __inline__ void ReadDataBcM1kMnk(
    T* dst,
    const T _global_ptr_* src,
    int thread_offset,
    const details::BroadcastConfig<Rank>& config,
    int read_lens) {
  const int out_index = thread_offset;
  const int src_begin = config(out_index);
  const int m = config.m;
  const int n = config.n;

  // Elements available before src_begin reaches the next multiple of m.
  const int until_row_end = m - src_begin % m;

  if (until_row_end >= read_lens) {
    // The whole request fits in the current row: one transfer.
    GM2LM(src + src_begin, dst, read_lens * sizeof(T));
    return;
  }

  // First part: up to the end of the current row.
  GM2LM(src + src_begin, dst, until_row_end * sizeof(T));

  // Second part: restart at the beginning of the same source row, unless the
  // output position is in the last of the n repeats, in which case continue
  // with the following source row.
  const int n_pos = out_index % (m * n) / m;
  const int row = src_begin / m;
  const int next_part_index = (n_pos != n - 1) ? row * m : (row + 1) * m;
  GM2LM(src + next_part_index,
        dst + until_row_end,
        (read_lens - until_row_end) * sizeof(T));
}
/**
 * @brief Read data from global memory to local memory with broadcast
 * {m, 1}-> {m, n} form.
 *
 * @template parameters
 * T: Data type of register.
 * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
 *
 * @param:
 * dst: Destination buffer of the thread; read_lens elements are written.
 * src: The original input data pointer of kernel.
 * thread_offset: The data offset of this thread.
 * config: Calculation configuration of broadcast. It is used to calculate the
 * coordinate mapping relationship between output data and input data.
 * read_lens: The number of data continuously loaded by each thread.
 */
template <typename T, int Rank>
__device__ __inline__ void ReadDataBcM1Mn(
    T* dst,
    const T _global_ptr_* src,
    int thread_offset,
    const details::BroadcastConfig<Rank>& config,
    int read_lens) {
  int index_output = thread_offset;
  int index_base = config(index_output);
  int m = config.m;
  int m_pos = index_base % m;
  if ((m - m_pos) < read_lens) {
    // The request runs past position m: copy the part up to m, then take the
    // remainder from the start of src.
    int last_col = m - m_pos;
    GM2LM(src + index_base, dst, last_col * sizeof(T));
    GM2LM(src, dst + last_col, (read_lens - last_col) * sizeof(T));
  } else {
    GM2LM(src + index_base, dst, read_lens * sizeof(T));
  }
}
/**
 * @brief Read data from global memory to local memory with broadcast
 * {1, n}-> {m, n} form.
 *
 * @template parameters
 * T: Data type of register.
 * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
 *
 * @param:
 * dst: Destination buffer of the thread; read_lens elements are written.
 * src: The original input data pointer of kernel.
 * thread_offset: The data offset of this thread.
 * config: Calculation configuration of broadcast. It is used to calculate the
 * coordinate mapping relationship between output data and input data.
 * read_lens: The number of data continuously loaded by each thread.
 */
template <typename T, int Rank>
__device__ __inline__ void ReadDataBc1NMn(
    T* dst,
    const T _global_ptr_* src,
    int thread_offset,
    const details::BroadcastConfig<Rank>& config,
    int read_lens) {
  int index_output = thread_offset;
  int index_base = config(index_output);
  int m = config.m;
  T in_temp;

  // Outputs left before index_output reaches the next multiple of m.
  int last_col = m - index_output % m;
  // Outputs [0, head) repeat src[index_base]; any remaining outputs repeat
  // src[index_base + 1].
  int head = (last_col < read_lens) ? last_col : read_lens;

  GM2LM(src + index_base, &in_temp, sizeof(T));
  for (int i = 0; i < head; i++) {
    dst[i] = in_temp;
  }

  if (head < read_lens) {
    GM2LM(src + index_base + 1, &in_temp, sizeof(T));
    for (int i = head; i < read_lens; i++) {
      dst[i] = in_temp;
    }
  }
}
/**
 * @brief Read data from global memory to local memory with broadcast
 * {1, n, 1}-> {m, n, k} form.
 *
 * @template parameters
 * T: Data type of register.
 * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
 *
 * @param:
 * dst: Destination buffer of the thread; read_lens elements are written.
 * src: The original input data pointer of kernel.
 * thread_offset: The data offset of this thread.
 * config: Calculation configuration of broadcast. It is used to calculate the
 * coordinate mapping relationship between output data and input data.
 * read_lens: The number of data continuously loaded by each thread.
 */
template <typename T, int Rank>
__device__ __inline__ void ReadDataBc1N1Mnk(
    T* dst,
    const T _global_ptr_* src,
    int thread_offset,
    const details::BroadcastConfig<Rank>& config,
    int read_lens) {
  const int out_index = thread_offset;
  const int src_index = config(out_index);
  const int m = config.m;
  const int n = config.n;
  T value;

  // Outputs left before out_index reaches the next multiple of m.
  const int until_boundary = m - out_index % m;
  // Outputs [0, first_run) repeat src[src_index].
  const int first_run =
      (until_boundary < read_lens) ? until_boundary : read_lens;

  GM2LM(src + src_index, &value, sizeof(T));
  for (int i = 0; i < first_run; i++) {
    dst[i] = value;
  }
  if (first_run == read_lens) {
    return;
  }

  // The remaining outputs repeat the next source element, wrapping to
  // element 0 after element n - 1.
  const int n_pos = out_index % (m * n) / m;
  const int next_part_index = (n_pos != n - 1) ? n_pos + 1 : 0;
  GM2LM(src + next_part_index, &value, sizeof(T));
  for (int i = first_run; i < read_lens; i++) {
    dst[i] = value;
  }
}
/**
 * @brief Read data from global memory to local memory with broadcast
 * {1}-> {n} form.
 *
 * A single source element, located with config(thread_offset), is loaded and
 * replicated into every one of the read_lens slots of dst.
 *
 * @template parameters
 * T: Data type of register.
 * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
 *
 * @param:
 * dst: The register pointer of the thread; read_lens elements are written.
 * src: The original input data pointer of kernel.
 * thread_offset: The data offset of this thread.
 * config: Calculation configuration of broadcast. It is used to calculate the
 * coordinate mapping relationship between output data and input data.
 * read_lens: The number of data continuously loaded by each thread.
 */
template <typename T, int Rank>
__device__ __inline__ void ReadDataBc1N(
    T* dst,
    const T _global_ptr_* src,
    int thread_offset,
    const details::BroadcastConfig<Rank>& config,
    int read_lens) {
  const int src_index = config(thread_offset);

  T value;
  GM2LM(src + src_index, &value, sizeof(T));

  int slot = 0;
  while (slot < read_lens) {
    dst[slot] = value;
    ++slot;
  }
}
/**
 * @brief Read data from global memory to local memory with a broadcast
 * form that cannot be compressed.
 *
 * Every output element is mapped to its source index with config(). A window
 * of 256 source elements starting at config(thread_offset) is fetched once;
 * source indices inside that window are served from it, all others are
 * fetched one element at a time.
 *
 * @template parameters
 * T: Data type of register.
 * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
 * IsBoundary: When true, reading stops at the first output index that is not
 * smaller than total_num_output; the remaining slots of dst are left
 * untouched.
 *
 * @param:
 * dst: The register pointer of the thread; up to read_lens elements are
 * written.
 * src: The original input data pointer of kernel.
 * thread_offset: The data offset of this thread.
 * config: Calculation configuration of broadcast. It is used to calculate the
 * coordinate mapping relationship between output data and input data.
 * total_num_output: Total number of original output.
 * read_lens: The number of data continuously loaded by each thread.
 */
template <typename T, int Rank, bool IsBoundary = false>
__device__ __inline__ void ReadDataBcCanNotCmp(
    T* dst,
    const T _global_ptr_* src,
    int thread_offset,
    const details::BroadcastConfig<Rank>& config,
    int total_num_output,
    int read_lens) {
  // Compile-time constant, so src_temp is an ordinary fixed-size array and
  // not a variable-length array (which is a non-standard extension in C++).
  constexpr int kCacheSize = 256;

  const int index_base = config(thread_offset);

  __local__ T src_temp[kCacheSize];
  // NOTE(review): this always fetches kCacheSize elements starting at
  // src + index_base, with no check against the length of src; confirm the
  // source buffer is large enough for that read.
  GM2LM(src + index_base, src_temp, kCacheSize * sizeof(T));

  T in_temp;
  for (int nx = 0; nx < read_lens; ++nx) {
    const int index_output = thread_offset + nx;
    if (IsBoundary && index_output >= total_num_output) {
      break;
    }

    const int index_src = config(index_output);
    if (index_src >= index_base && index_src < index_base + kCacheSize) {
      // Hit: the element is already in the prefetched window.
      in_temp = src_temp[index_src - index_base];
    } else {
      // Miss: fetch this single element.
      GM2LM(src + index_src, &in_temp, sizeof(T));
    }
    dst[nx] = in_temp;
  }
}
/**
 * @brief Read 1D data from global memory to register with broadcast form.
 *
 * Computes the offset of the calling core as
 * block_offset + core_id() * read_lens and forwards to the reader that
 * matches config.cmp_type. Any cmp_type without a dedicated reader goes to
 * ReadDataBcCanNotCmp, which is the only reader that receives IsBoundary and
 * total_num_output.
 *
 * @template parameters
 * T: The type of data stored in the global memory.
 * NX: The number of data continuously loaded by each thread. Not referenced
 * in the body of this function.
 * NY: The number of data rows loaded by each thread, only NY = 1 was
 * supported. Not referenced in the body of this function.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 * Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
 * IsBoundary: Indicates whether to perform block access storage out-of-bounds
 * judgment. It is only forwarded to ReadDataBcCanNotCmp.
 *
 * @param:
 * dst: The register pointer of the thread.
 * src: The original input data pointer of kernel.
 * block_offset: The data offset of this block.
 * config: Calculation configuration of broadcast. It is used to calculate the
 * coordinate mapping relationship between output data and input data.
 * total_num_output: Total number of original output.
 * read_lens: The number of data continuously loaded by each thread.
 */
template <typename T,
          int NX,
          int NY,
          int BlockSize,
          int Rank,
          bool IsBoundary = false>
__device__ __inline__ void ReadDataBc(
    T* dst,
    const T _global_ptr_* src,
    uint32_t block_offset,
    const details::BroadcastConfig<Rank>& config,
    int total_num_output,
    int read_lens) {
  const int thread_offset = block_offset + core_id() * read_lens;

  switch (config.cmp_type) {
    case details::OptType::MNK_M1K:
      ReadDataBcM1kMnk<T, Rank>(dst, src, thread_offset, config, read_lens);
      break;
    case details::OptType::N_1:
      ReadDataBc1N<T, Rank>(dst, src, thread_offset, config, read_lens);
      break;
    case details::OptType::MN_M:
      ReadDataBcM1Mn<T, Rank>(dst, src, thread_offset, config, read_lens);
      break;
    case details::OptType::MN_N:
      ReadDataBc1NMn<T, Rank>(dst, src, thread_offset, config, read_lens);
      break;
    case details::OptType::MNK_1N1:
      ReadDataBc1N1Mnk<T, Rank>(dst, src, thread_offset, config, read_lens);
      break;
    default:
      // No dedicated reader for this cmp_type: use the generic reader.
      ReadDataBcCanNotCmp<T, Rank, IsBoundary>(
          dst, src, thread_offset, config, total_num_output, read_lens);
      break;
  }
}
/**
* @brief Initialize register with data index.
*
...
...
paddle/phi/kernels/primitive/kernel_primitives.h
浏览文件 @
c7855125
...
...
@@ -46,6 +46,7 @@
#define KPStream gpuStream_t
#define KPDevice phi::GPUContext
#define _ptr_
#define __simd__
#define THREAD_ID_X threadIdx.x
#define THREAD_ID_Y threadIdx.y
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录