Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
c7855125
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c7855125
编写于
5月 10, 2022
作者:
S
shixingbo
提交者:
GitHub
5月 10, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
broadcast_add kp performance optimization (#42097)
上级
81078a88
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
880 addition
and
43 deletion
+880
-43
paddle/phi/kernels/funcs/broadcast_function.h
paddle/phi/kernels/funcs/broadcast_function.h
+107
-31
paddle/phi/kernels/funcs/elementwise_base.h
paddle/phi/kernels/funcs/elementwise_base.h
+49
-7
paddle/phi/kernels/primitive/compute_primitives.h
paddle/phi/kernels/primitive/compute_primitives.h
+14
-0
paddle/phi/kernels/primitive/compute_primitives_xpu2.h
paddle/phi/kernels/primitive/compute_primitives_xpu2.h
+14
-0
paddle/phi/kernels/primitive/datamover_primitives.h
paddle/phi/kernels/primitive/datamover_primitives.h
+104
-0
paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
+591
-5
paddle/phi/kernels/primitive/kernel_primitives.h
paddle/phi/kernels/primitive/kernel_primitives.h
+1
-0
未找到文件。
paddle/phi/kernels/funcs/broadcast_function.h
浏览文件 @
c7855125
...
...
@@ -242,6 +242,27 @@ __device__ __forceinline__ void LoadData(
}
}
template
<
typename
T
,
int
VecSize
,
int
Rank
,
bool
IsBoundary
=
false
>
__device__
__forceinline__
void
LoadData
(
T
*
dst
,
const
_ptr_
T
*
src
,
uint32_t
block_offset
,
const
kps
::
details
::
BroadcastConfig
<
Rank
>
&
config
,
int
numel
,
int
num
,
int
need_broadcast
,
int
read_lens
)
{
// numel : whole num of output
// num: how many data will be deal with in this time
if
(
need_broadcast
)
{
kps
::
ReadDataBc
<
T
,
VecSize
,
1
,
1
,
Rank
,
IsBoundary
>
(
dst
,
src
,
block_offset
,
config
,
numel
,
read_lens
);
}
else
{
kps
::
ReadData
<
T
,
VecSize
,
1
,
1
,
IsBoundary
>
(
dst
,
src
+
block_offset
,
num
,
read_lens
);
}
}
template
<
typename
InT
,
typename
OutT
,
typename
Functor
,
...
...
@@ -258,20 +279,22 @@ __device__ void VectorizedBroadcastKernelImpl(
const
phi
::
Array
<
kps
::
details
::
BroadcastConfig
<
Rank
>
,
Arity
>
&
configs
,
int
num
,
int
block_offset
,
int
read_lens
,
Functor
func
)
{
InT
args
[
Arity
][
VecSize
];
ConditionalT
<
OutT
,
NumOuts
>
result
[
VecSize
];
__simd__
InT
args
[
Arity
][
VecSize
];
__simd__
ConditionalT
<
OutT
,
NumOuts
>
result
[
VecSize
];
#pragma unroll
for
(
int
i
=
0
;
i
<
Arity
;
i
++
)
{
kps
::
Init
<
InT
,
VecSize
>
(
args
[
i
],
static_cast
<
InT
>
(
1.0
f
));
kps
::
Init
<
InT
,
VecSize
>
(
args
[
i
],
static_cast
<
InT
>
(
1.0
f
)
,
read_lens
);
LoadData
<
InT
,
VecSize
,
Rank
,
IsBoundary
>
(
args
[
i
],
ins
[
i
],
block_offset
,
configs
[
i
],
numel
,
num
,
use_broadcast
[
i
]);
use_broadcast
[
i
],
read_lens
);
}
constexpr
bool
kCallElementwiseAny
=
paddle
::
platform
::
FunctionTraits
<
Functor
>::
has_pointer_args
;
...
...
@@ -281,10 +304,10 @@ __device__ void VectorizedBroadcastKernelImpl(
Functor
,
Arity
,
kCallElementwiseAny
>
()(
func
,
args
,
result
);
phi
::
funcs
::
ElementwiseWriteDataCaller
<
OutT
,
VecSize
,
IsBoundary
,
NumOuts
>
()(
outs
,
result
,
block_offset
,
num
);
func
,
args
,
result
,
read_lens
);
phi
::
funcs
::
ElementwiseWriteDataCallerBc
<
OutT
,
VecSize
,
IsBoundary
,
NumOuts
>
()(
outs
,
result
,
block_offset
,
num
,
read_lens
);
}
template
<
typename
InT
,
...
...
@@ -302,9 +325,10 @@ __global__ void VectorizedBroadcastKernel(
phi
::
Array
<
kps
::
details
::
BroadcastConfig
<
Rank
>
,
Arity
>
configs
,
int
main_offset
,
int
tail_tid
,
int
read_lens
,
Functor
func
)
{
int
block_offset
=
BLOCK_ID_X
*
BLOCK_NUM_X
*
VecSize
;
int
stride
=
BLOCK_NUM_X
*
GRID_NUM_X
*
VecSize
;
int
block_offset
=
BLOCK_ID_X
*
BLOCK_NUM_X
*
read_lens
;
int
stride
=
BLOCK_NUM_X
*
GRID_NUM_X
*
read_lens
;
#ifdef PADDLE_WITH_XPU_KP
for
(;
block_offset
<
main_offset
;
block_offset
+=
stride
)
{
...
...
@@ -320,8 +344,9 @@ __global__ void VectorizedBroadcastKernel(
use_broadcast
,
numel
,
configs
,
BLOCK_NUM_X
*
VecSize
,
BLOCK_NUM_X
*
read_lens
,
block_offset
,
read_lens
,
func
);
}
int
num
=
numel
-
block_offset
;
...
...
@@ -333,8 +358,15 @@ __global__ void VectorizedBroadcastKernel(
NumOuts
,
VecSize
,
Rank
,
true
>
(
ins
,
outs
,
use_broadcast
,
numel
,
configs
,
num
,
block_offset
,
func
);
true
>
(
ins
,
outs
,
use_broadcast
,
numel
,
configs
,
num
,
block_offset
,
read_lens
,
func
);
}
#else
if
(
block_offset
<
main_offset
)
{
...
...
@@ -352,6 +384,7 @@ __global__ void VectorizedBroadcastKernel(
configs
,
BLOCK_NUM_X
*
VecSize
,
block_offset
,
read_lens
,
func
);
}
else
{
VectorizedBroadcastKernelImpl
<
InT
,
...
...
@@ -361,8 +394,15 @@ __global__ void VectorizedBroadcastKernel(
NumOuts
,
VecSize
,
Rank
,
true
>
(
ins
,
outs
,
use_broadcast
,
numel
,
configs
,
tail_tid
,
block_offset
,
func
);
true
>
(
ins
,
outs
,
use_broadcast
,
numel
,
configs
,
tail_tid
,
block_offset
,
read_lens
,
func
);
}
#endif
}
...
...
@@ -392,6 +432,19 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
for
(
int
i
=
0
;
i
<
Arity
;
i
++
)
{
use_broadcast
[
i
]
=
(
ins
[
i
]
->
numel
()
!=
numel
);
ins_data
[
i
]
=
(
const
_ptr_
InT
*
)(
ins
[
i
]
->
data
<
InT
>
());
#ifdef PADDLE_WITH_XPU_KP
if
(
i
==
0
)
{
configs
[
i
]
=
kps
::
details
::
BroadcastConfig
<
Rank
>
(
merge_dims
.
out_dims
,
merge_dims
.
in_dims
[
0
],
merge_dims
.
in_dims
[
1
],
merge_dims
.
dim_size
);
}
else
if
(
i
==
1
)
{
configs
[
i
]
=
kps
::
details
::
BroadcastConfig
<
Rank
>
(
merge_dims
.
out_dims
,
merge_dims
.
in_dims
[
1
],
merge_dims
.
in_dims
[
0
],
merge_dims
.
dim_size
);
}
#else
if
(
use_broadcast
[
i
])
{
// get the broadcast config,
// if data shape is[m, n], then you should set data_dim = {n, m}
...
...
@@ -399,28 +452,50 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
configs
[
i
]
=
kps
::
details
::
BroadcastConfig
<
Rank
>
(
merge_dims
.
out_dims
,
merge_dims
.
in_dims
[
i
],
merge_dims
.
dim_size
);
}
#endif
}
#ifdef PADDLE_WITH_XPU_KP
const
int
threads
=
64
;
const
int
blocks
=
8
;
int
main_offset
=
(
numel
/
(
VecSize
*
threads
))
*
VecSize
*
threads
;
int
tail_tid
=
numel
%
(
VecSize
*
threads
);
int
read_lens
=
configs
[
0
].
buf_len
;
int
main_offset
=
(
numel
/
(
read_lens
*
threads
))
*
read_lens
*
threads
;
int
tail_tid
=
numel
%
(
read_lens
*
threads
);
auto
stream
=
ctx
.
x_context
()
->
xpu_stream
;
VectorizedBroadcastKernel
<
InT
,
OutT
,
Functor
,
Arity
,
NumOuts
,
VecSize
,
Rank
><<<
blocks
,
threads
,
stream
>>>
(
ins_data
,
outs_data
,
use_broadcast
,
numel
,
configs
,
main_offset
,
tail_tid
,
func
);
if
(
configs
[
0
].
cmp_type
!=
kps
::
details
::
OptType
::
CanNotOptimize
)
{
main_offset
=
numel
;
VectorizedBroadcastKernel
<
InT
,
OutT
,
Functor
,
Arity
,
NumOuts
,
512
,
Rank
><<<
blocks
,
threads
,
stream
>>>
(
ins_data
,
outs_data
,
use_broadcast
,
numel
,
configs
,
main_offset
,
tail_tid
,
read_lens
,
func
);
}
else
{
VectorizedBroadcastKernel
<
InT
,
OutT
,
Functor
,
Arity
,
NumOuts
,
256
,
Rank
><<<
blocks
,
threads
,
stream
>>>
(
ins_data
,
outs_data
,
use_broadcast
,
numel
,
configs
,
main_offset
,
tail_tid
,
read_lens
,
func
);
}
#else
const
int
threads
=
256
;
int
blocks
=
((
numel
+
VecSize
-
1
)
/
VecSize
+
threads
-
1
)
/
threads
;
...
...
@@ -440,6 +515,7 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
configs
,
main_offset
,
tail_tid
,
VecSize
,
func
);
#endif
}
...
...
paddle/phi/kernels/funcs/elementwise_base.h
浏览文件 @
c7855125
...
...
@@ -577,14 +577,16 @@ template <typename InT,
struct
ElementwisePrimitiveCaller
{
__device__
inline
void
operator
()(
Functor
func
,
InT
(
*
args
)[
VecSize
],
OutT
*
result
);
OutT
*
result
,
int
read_lens
);
};
template
<
typename
InT
,
typename
OutT
,
int
VecSize
,
typename
Functor
,
int
Arity
>
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
Arity
,
true
>
{
__device__
inline
void
operator
()(
Functor
func
,
InT
(
*
args
)[
VecSize
],
OutT
*
result
)
{
OutT
*
result
,
int
read_lens
)
{
kps
::
ElementwiseAny
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Arity
,
Functor
>
(
result
,
args
,
func
);
}
...
...
@@ -594,7 +596,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
0
,
false
>
{
__device__
inline
void
operator
()(
Functor
func
,
InT
(
*
args
)[
VecSize
],
OutT
*
result
)
{
OutT
*
result
,
int
read_lens
)
{
kps
::
ElementwiseConstant
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Functor
>
(
result
,
func
);
}
};
...
...
@@ -603,7 +606,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
1
,
false
>
{
__device__
inline
void
operator
()(
Functor
func
,
InT
(
*
args
)[
VecSize
],
OutT
*
result
)
{
OutT
*
result
,
int
read_lens
)
{
kps
::
ElementwiseUnary
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Functor
>
(
result
,
args
[
0
],
func
);
}
...
...
@@ -613,9 +617,10 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
2
,
false
>
{
__device__
inline
void
operator
()(
Functor
func
,
InT
(
*
args
)[
VecSize
],
OutT
*
result
)
{
OutT
*
result
,
int
read_lens
)
{
kps
::
ElementwiseBinary
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Functor
>
(
result
,
args
[
0
],
args
[
1
],
func
);
result
,
args
[
0
],
args
[
1
],
func
,
read_lens
);
}
};
...
...
@@ -623,7 +628,8 @@ template <typename InT, typename OutT, int VecSize, typename Functor>
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
3
,
false
>
{
__device__
inline
void
operator
()(
Functor
func
,
InT
(
*
args
)[
VecSize
],
OutT
*
result
)
{
OutT
*
result
,
int
read_lens
)
{
kps
::
ElementwiseTernary
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Functor
>
(
result
,
args
[
0
],
args
[
1
],
args
[
2
],
func
);
}
...
...
@@ -696,6 +702,42 @@ struct ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, 1> {
}
};
template
<
typename
OutT
,
int
VecSize
,
bool
IsBoundary
,
int
NumOuts
>
struct
ElementwiseWriteDataCallerBc
{
__device__
__forceinline__
void
operator
()(
phi
::
Array
<
_ptr_
OutT
*
,
NumOuts
>
outs
,
ConditionalT
<
OutT
,
NumOuts
>
src
[
VecSize
],
int
block_offset
,
int
num
,
int
read_lens
)
{
OutT
dst
[
NumOuts
][
VecSize
];
#pragma unroll
for
(
int
i
=
0
;
i
<
read_lens
;
++
i
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
NumOuts
;
++
j
)
{
dst
[
j
][
i
]
=
(
src
[
i
])[
j
];
}
}
#pragma unroll
for
(
int
i
=
0
;
i
<
NumOuts
;
++
i
)
{
kps
::
WriteData
<
OutT
,
VecSize
,
1
,
1
,
IsBoundary
>
(
outs
[
i
]
+
block_offset
,
dst
[
i
],
num
,
read_lens
);
}
}
};
template
<
typename
OutT
,
int
VecSize
,
bool
IsBoundary
>
struct
ElementwiseWriteDataCallerBc
<
OutT
,
VecSize
,
IsBoundary
,
1
>
{
__device__
__forceinline__
void
operator
()(
phi
::
Array
<
_ptr_
OutT
*
,
1
>
outs
,
OutT
src
[
VecSize
],
int
block_offset
,
int
num
,
int
read_lens
)
{
kps
::
WriteData
<
OutT
,
VecSize
,
1
,
1
,
IsBoundary
>
(
outs
[
0
]
+
block_offset
,
src
,
num
,
read_lens
);
}
};
template
<
typename
OutT
,
typename
Functor
,
int
Arity
,
...
...
paddle/phi/kernels/primitive/compute_primitives.h
浏览文件 @
c7855125
...
...
@@ -271,6 +271,20 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out,
}
}
template
<
typename
InT
,
typename
OutT
,
int
NX
,
int
NY
,
int
BlockSize
,
class
OpFunc
>
__device__
__forceinline__
void
ElementwiseBinary
(
OutT
*
out
,
const
InT
*
in1
,
const
InT
*
in2
,
OpFunc
compute
,
int
read_lens
)
{
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
*
NY
;
++
idx
)
{
out
[
idx
]
=
static_cast
<
OutT
>
(
compute
(
in1
[
idx
],
in2
[
idx
]));
}
}
/**
* @brief Ternary calculation according to OpFunc. Shape of input and output
* are the same.
...
...
paddle/phi/kernels/primitive/compute_primitives_xpu2.h
浏览文件 @
c7855125
...
...
@@ -17,6 +17,7 @@
#include "xpu/kernel/cluster_header.h"
#include "xpu/kernel/debug.h"
#include "xpu/kernel/math.h"
#include "xpu/kernel/simd_header.h"
namespace
phi
{
namespace
kps
{
...
...
@@ -158,6 +159,19 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out,
}
}
template
<
typename
InT
,
typename
OutT
,
int
NX
,
int
NY
,
int
BlockSize
,
class
OpFunc
>
__device__
__forceinline__
void
ElementwiseBinary
(
OutT
*
out
,
const
InT
*
in1
,
const
InT
*
in2
,
OpFunc
compute
,
int
read_lens
)
{
for
(
int
idx
=
0
;
idx
<
read_lens
;
++
idx
)
{
out
[
idx
]
=
static_cast
<
OutT
>
(
compute
(
in1
[
idx
],
in2
[
idx
]));
}
}
/**
* @brief Ternary calculation according to OpFunc. Shape of input and output
* are the same.
...
...
paddle/phi/kernels/primitive/datamover_primitives.h
浏览文件 @
c7855125
...
...
@@ -246,6 +246,14 @@ __device__ __forceinline__ void Init(T* dst, T init_data) {
}
}
template
<
typename
T
,
int
NX
>
__device__
__forceinline__
void
Init
(
T
*
dst
,
T
init_data
,
int
read_lens
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
NX
;
i
++
)
{
dst
[
i
]
=
init_data
;
}
}
/**
* The difference from the above function is that
* it supports different data types of inputs.
...
...
@@ -311,6 +319,38 @@ __device__ __forceinline__ void ReadData(T* dst,
}
}
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
bool
IsBoundary
=
false
>
__device__
__forceinline__
void
ReadData
(
T
*
dst
,
const
T
*
__restrict__
src
,
int
num
,
int
read_lens
)
{
if
(
IsBoundary
)
{
// blockDim.x * NX > num
int
thread_offset
=
threadIdx
.
x
*
NX
;
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
;
++
idx
)
{
if
(
idx
+
thread_offset
<
num
)
{
dst
[
idx
]
=
src
[
thread_offset
+
idx
];
}
}
}
else
{
// blockDim,x * NX < num
constexpr
int
kVectorSize
=
(
NX
%
4
==
0
)
?
4
:
(
NX
%
2
==
0
)
?
2
:
1
;
constexpr
int
kVectorsPerThread
=
NX
/
kVectorSize
;
int
thread_offset
=
threadIdx
.
x
*
kVectorsPerThread
;
using
VecType
=
details
::
VectorType
<
T
,
kVectorSize
>
;
const
VecType
*
vec_input
=
reinterpret_cast
<
const
VecType
*>
(
src
);
VecType
vec_temp
[
kVectorsPerThread
];
#pragma unroll
for
(
int
i
=
0
;
i
<
kVectorsPerThread
;
++
i
)
{
vec_temp
[
i
]
=
vec_input
[
thread_offset
+
i
];
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
;
++
idx
)
{
dst
[
idx
]
=
*
(
reinterpret_cast
<
T
*>
(
vec_temp
)
+
idx
);
}
}
}
}
/**
* @brief Read 1D data from global memory to register. The difference
* from the above function is that it supports different data types of inputs.
...
...
@@ -576,6 +616,36 @@ __device__ __forceinline__ void WriteData(T* dst,
}
}
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
bool
IsBoundary
=
false
>
__device__
__forceinline__
void
WriteData
(
T
*
dst
,
T
*
__restrict__
src
,
int
num
,
int
read_lens
)
{
if
(
IsBoundary
)
{
int
thread_offset
=
threadIdx
.
x
*
NX
;
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
;
++
idx
)
{
if
((
thread_offset
+
idx
)
<
num
)
{
dst
[
thread_offset
+
idx
]
=
src
[
idx
];
}
}
}
else
{
// Vector type
constexpr
int
kVectorSize
=
(
NX
%
4
==
0
)
?
4
:
(
NX
%
2
==
0
)
?
2
:
1
;
constexpr
int
kVectorsPerThread
=
NX
/
kVectorSize
;
int
thread_offset
=
threadIdx
.
x
*
kVectorsPerThread
;
using
VecType
=
details
::
VectorType
<
T
,
kVectorSize
>
;
VecType
*
vec_dst
=
reinterpret_cast
<
VecType
*>
(
dst
);
VecType
vec_temp
[
kVectorsPerThread
];
#pragma unroll
for
(
int
idx
=
0
;
idx
<
kVectorsPerThread
;
++
idx
)
{
vec_temp
[
idx
]
=
*
(
reinterpret_cast
<
VecType
*>
(
src
)
+
idx
);
vec_dst
[
thread_offset
+
idx
]
=
vec_temp
[
idx
];
}
}
}
/**
* @brief Write 2D data from register to global memory according to Tx type, and
* store it as Ty type.
...
...
@@ -749,6 +819,40 @@ __device__ __forceinline__ void ReadDataBc(
}
}
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
int
Rank
,
bool
IsBoundary
=
false
>
__device__
__forceinline__
void
ReadDataBc
(
T
*
dst
,
const
T
*
__restrict__
src
,
uint32_t
block_offset
,
details
::
BroadcastConfig
<
Rank
>
config
,
int
total_num_output
,
int
read_lens
)
{
uint32_t
thread_offset
=
block_offset
+
threadIdx
.
x
*
NX
;
uint32_t
index_src
=
0
;
#pragma unroll
for
(
uint32_t
nx
=
0
;
nx
<
NX
;
++
nx
)
{
uint32_t
index_output
=
thread_offset
+
nx
;
index_src
=
0
;
if
(
IsBoundary
)
{
if
(
index_output
>=
total_num_output
)
{
break
;
}
}
#pragma unroll
for
(
int
i
=
0
;
i
<
Rank
;
++
i
)
{
auto
fast_divmoder
=
config
.
divmoders
[
i
].
Divmod
(
index_output
);
index_output
=
fast_divmoder
.
val
[
0
];
index_src
+=
fast_divmoder
.
val
[
1
]
*
config
.
strides
[
i
];
}
dst
[
nx
]
=
src
[
index_src
];
}
}
/**
* @brief Initialize register with data index.
*
...
...
paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
浏览文件 @
c7855125
...
...
@@ -21,6 +21,39 @@ namespace phi {
namespace
kps
{
namespace
details
{
enum
class
OptType
{
// Optimize type of calc after input shape compressed
CanNotOptimize
=
-
1
,
// can not optimize, broadcast first
N_1
,
// just like {1} op {100} or {100} op {1}
MN_N
,
// just like {100} op {3, 100} or {3, 100} op {100}
MN_M
,
// just like {3} op {3, 100} or {3, 100} op {3}
MNK_1N1
,
// just like {3} op {2, 3, 100} or {2, 3, 100} op {3}
MNK_M1K
,
// just like {2, 1, 100} op {2, 3, 100} or {2, 3, 100} op {2, 1,
// 100}
};
// Rules to determine whether dimensions can be merged
// rule 0 - xshape[idx] == yshape[idx]
// rule 1 - xshape[idx] == 1 && yshape[idx] != 1
// rule 2 - xshape[idx] != 1 && yshape[idx] == 1
static
int
judge_case
(
int
a
,
int
b
)
{
if
(
a
==
b
)
{
return
0
;
}
else
if
(
a
==
1
&&
b
!=
1
)
{
return
1
;
}
else
if
(
a
!=
1
&&
b
==
1
)
{
return
2
;
}
return
-
1
;
}
static
bool
case_is_same
(
int
case_front
,
int
case_back
)
{
if
(
case_front
==
case_back
)
{
return
true
;
}
else
{
return
false
;
}
}
template
<
typename
T
,
int
VecSize
>
struct
alignas
(
sizeof
(
T
)
*
VecSize
)
VectorType
{
T
val
[
VecSize
];
...
...
@@ -37,11 +70,20 @@ struct BroadcastConfig {
int
strides_in
[
phi
::
DDim
::
kMaxRank
];
int
strides_out
[
phi
::
DDim
::
kMaxRank
];
int
in_dim
[
phi
::
DDim
::
kMaxRank
];
int
dim_after_cmp
[
phi
::
DDim
::
kMaxRank
];
int
dim_size_after_cmp
=
0
;
int
cmp_res
=
0
;
OptType
cmp_type
=
OptType
::
CanNotOptimize
;
int
m
=
1
;
int
n
=
1
;
int
k
=
1
;
int
buf_len
=
0
;
HOSTDEVICE
BroadcastConfig
()
{}
HOSTDEVICE
BroadcastConfig
(
const
std
::
vector
<
int64_t
>&
out_dims
,
const
std
::
vector
<
int64_t
>&
in_dims
,
const
std
::
vector
<
int64_t
>&
another_in_dims
,
int
dim_size
)
{
std
::
vector
<
int
>
strides_in_tmp
;
std
::
vector
<
int
>
strides_out_tmp
;
...
...
@@ -61,18 +103,187 @@ struct BroadcastConfig {
memcpy
(
strides_in
,
strides_in_tmp
.
data
(),
kDims
*
sizeof
(
int
));
memcpy
(
strides_out
,
strides_out_tmp
.
data
(),
kDims
*
sizeof
(
int
));
memcpy
(
in_dim
,
dim_tmp
.
data
(),
kDims
*
sizeof
(
int
));
cmp_res
=
get_mnk_for_broadcast_ops
(
in_dims
,
another_in_dims
);
get_opt_type
(
another_in_dims
);
buf_len
=
get_buf_len
();
}
int
get_buf_len
()
{
if
(
cmp_type
==
OptType
::
CanNotOptimize
)
{
return
256
;
}
int
max_buf_len
=
512
;
int
buf_len
=
m
/
16
*
16
;
if
(
buf_len
==
0
)
{
buf_len
=
m
;
}
return
std
::
min
(
max_buf_len
,
buf_len
);
}
__device__
inline
int
operator
()(
int
index_output
)
const
{
int
index_src
=
0
;
#pragma unroll
for
(
int
i
=
kDims
-
1
;
i
>=
0
;
--
i
)
{
int
tmp_index
=
(
index_output
/
strides_out
[
i
]);
index_output
=
index_output
-
tmp_index
*
strides_out
[
i
];
index_src
+=
(
tmp_index
%
in_dim
[
i
])
*
strides_in
[
i
];
switch
(
cmp_type
)
{
int
div
,
mod
,
tmp_index
;
case
OptType
::
MNK_M1K
:
div
=
index_output
/
(
m
*
n
);
mod
=
index_output
%
(
m
*
n
)
%
m
;
index_src
=
div
*
m
+
mod
;
break
;
case
OptType
::
MNK_1N1
:
// index_src = index_output / m % n;
index_src
=
index_output
%
(
m
*
n
)
/
m
;
break
;
case
OptType
::
N_1
:
index_src
=
0
;
break
;
case
OptType
::
MN_N
:
index_src
=
index_output
/
m
;
break
;
case
OptType
::
MN_M
:
index_src
=
index_output
%
m
;
break
;
case
OptType
::
CanNotOptimize
:
for
(
int
i
=
kDims
-
1
;
i
>=
0
;
--
i
)
{
tmp_index
=
(
index_output
/
strides_out
[
i
]);
index_output
=
index_output
-
tmp_index
*
strides_out
[
i
];
index_src
+=
(
tmp_index
%
in_dim
[
i
])
*
strides_in
[
i
];
}
break
;
}
return
index_src
;
}
void
get_opt_type
(
const
std
::
vector
<
int64_t
>&
y_dim_after_cmp
)
{
if
(
dim_size_after_cmp
==
1
)
{
if
(
dim_after_cmp
[
0
]
==
1
&&
y_dim_after_cmp
[
0
]
!=
1
)
{
// {1} op {n}
n
=
y_dim_after_cmp
[
0
];
cmp_type
=
OptType
::
N_1
;
}
else
if
(
dim_after_cmp
[
0
]
!=
1
&&
y_dim_after_cmp
[
0
]
==
1
)
{
// {n} op {1}
n
=
dim_after_cmp
[
0
];
cmp_type
=
OptType
::
N_1
;
}
else
{
cmp_type
=
OptType
::
CanNotOptimize
;
// xshape == yshape
}
}
if
(
dim_size_after_cmp
==
2
)
{
if
(
dim_after_cmp
[
0
]
==
1
&&
dim_after_cmp
[
1
]
!=
1
&&
y_dim_after_cmp
[
0
]
!=
1
&&
y_dim_after_cmp
[
1
]
!=
1
)
{
// {n} op {m, n}
m
=
y_dim_after_cmp
[
0
];
n
=
y_dim_after_cmp
[
1
];
cmp_type
=
OptType
::
MN_N
;
}
else
if
(
dim_after_cmp
[
0
]
!=
1
&&
dim_after_cmp
[
1
]
==
1
&&
y_dim_after_cmp
[
0
]
!=
1
&&
y_dim_after_cmp
[
1
]
!=
1
)
{
// {m} op {m, n}
m
=
y_dim_after_cmp
[
0
];
n
=
y_dim_after_cmp
[
1
];
cmp_type
=
OptType
::
MN_M
;
}
else
if
(
dim_after_cmp
[
0
]
!=
1
&&
dim_after_cmp
[
1
]
!=
1
&&
y_dim_after_cmp
[
0
]
==
1
&&
y_dim_after_cmp
[
1
]
!=
1
)
{
// {m, n} op {n}
m
=
dim_after_cmp
[
0
];
n
=
dim_after_cmp
[
1
];
cmp_type
=
OptType
::
MN_N
;
}
else
if
(
dim_after_cmp
[
0
]
!=
1
&&
dim_after_cmp
[
1
]
!=
1
&&
y_dim_after_cmp
[
0
]
!=
1
&&
y_dim_after_cmp
[
1
]
==
1
)
{
// {m, n} op {m}
m
=
dim_after_cmp
[
0
];
n
=
dim_after_cmp
[
1
];
cmp_type
=
OptType
::
MN_M
;
}
else
{
cmp_type
=
OptType
::
CanNotOptimize
;
}
}
if
(
dim_size_after_cmp
==
3
)
{
if
(
dim_after_cmp
[
0
]
==
1
&&
dim_after_cmp
[
1
]
!=
1
&&
dim_after_cmp
[
2
]
==
1
&&
y_dim_after_cmp
[
0
]
!=
1
&&
y_dim_after_cmp
[
1
]
!=
1
&&
y_dim_after_cmp
[
2
]
!=
1
)
{
// {1, n, 1} op {m, n, k}
m
=
y_dim_after_cmp
[
0
];
n
=
y_dim_after_cmp
[
1
];
k
=
y_dim_after_cmp
[
2
];
cmp_type
=
OptType
::
MNK_1N1
;
}
else
if
(
dim_after_cmp
[
0
]
!=
1
&&
dim_after_cmp
[
1
]
!=
1
&&
dim_after_cmp
[
2
]
!=
1
&&
y_dim_after_cmp
[
0
]
==
1
&&
y_dim_after_cmp
[
1
]
!=
1
&&
y_dim_after_cmp
[
2
]
==
1
)
{
// {m, n, k} op {1, n, 1}
m
=
dim_after_cmp
[
0
];
n
=
dim_after_cmp
[
1
];
k
=
dim_after_cmp
[
2
];
cmp_type
=
OptType
::
MNK_1N1
;
}
else
if
(
dim_after_cmp
[
0
]
!=
1
&&
dim_after_cmp
[
1
]
==
1
&&
dim_after_cmp
[
2
]
!=
1
&&
y_dim_after_cmp
[
0
]
!=
1
&&
y_dim_after_cmp
[
1
]
!=
1
&&
y_dim_after_cmp
[
2
]
!=
1
)
{
// {m, 1, k} op {m, n, k}
m
=
y_dim_after_cmp
[
0
];
n
=
y_dim_after_cmp
[
1
];
k
=
y_dim_after_cmp
[
2
];
cmp_type
=
OptType
::
MNK_M1K
;
}
else
if
(
dim_after_cmp
[
0
]
!=
1
&&
dim_after_cmp
[
1
]
!=
1
&&
dim_after_cmp
[
2
]
!=
1
&&
y_dim_after_cmp
[
0
]
!=
1
&&
y_dim_after_cmp
[
1
]
==
1
&&
y_dim_after_cmp
[
2
]
!=
1
)
{
// {m, n, k} op {m, 1, k}
m
=
dim_after_cmp
[
0
];
n
=
dim_after_cmp
[
1
];
k
=
dim_after_cmp
[
2
];
cmp_type
=
OptType
::
MNK_M1K
;
}
else
{
cmp_type
=
OptType
::
CanNotOptimize
;
}
}
}
int
get_mnk_for_broadcast_ops
(
const
std
::
vector
<
int64_t
>&
xshape
,
const
std
::
vector
<
int64_t
>&
yshape
)
{
int
idx
=
0
;
int
cmp_x
=
0
;
int
cmp_y
=
0
;
bool
is_same
=
false
;
std
::
vector
<
int64_t
>
xshape_after_remove_ones
=
xshape
;
std
::
vector
<
int64_t
>
yshape_after_remove_ones
=
yshape
;
// first step: remove excess ones
std
::
vector
<
int64_t
>::
iterator
x_iter
=
xshape_after_remove_ones
.
begin
();
std
::
vector
<
int64_t
>::
iterator
y_iter
=
yshape_after_remove_ones
.
begin
();
for
(;
x_iter
!=
xshape_after_remove_ones
.
end
();)
{
if
(
*
x_iter
==
1
&&
*
y_iter
==
1
)
{
x_iter
=
xshape_after_remove_ones
.
erase
(
x_iter
);
y_iter
=
yshape_after_remove_ones
.
erase
(
y_iter
);
}
else
{
x_iter
++
;
y_iter
++
;
}
}
// second step: compress dims
int
after_cmp_idx
=
0
;
for
(
int
i
=
0
;
i
<
3
;
i
++
)
{
cmp_x
=
xshape_after_remove_ones
[
idx
];
cmp_y
=
yshape_after_remove_ones
[
idx
];
while
((
idx
+
1
)
<
xshape_after_remove_ones
.
size
())
{
is_same
=
case_is_same
(
judge_case
(
xshape_after_remove_ones
[
idx
],
yshape_after_remove_ones
[
idx
]),
judge_case
(
xshape_after_remove_ones
[
idx
+
1
],
yshape_after_remove_ones
[
idx
+
1
]));
if
(
is_same
)
{
cmp_x
=
cmp_x
*
xshape_after_remove_ones
[
idx
+
1
];
cmp_y
=
cmp_y
*
yshape_after_remove_ones
[
idx
+
1
];
idx
++
;
}
else
{
break
;
}
}
idx
=
idx
+
1
;
dim_after_cmp
[
after_cmp_idx
]
=
cmp_x
;
after_cmp_idx
++
;
if
(
idx
==
xshape_after_remove_ones
.
size
())
{
dim_size_after_cmp
=
after_cmp_idx
;
return
0
;
}
}
return
-
1
;
// can not compress dims
}
};
#pragma pack()
...
...
@@ -199,6 +410,14 @@ __device__ __inline__ void Init(T* dst, T init_data) {
}
}
template
<
typename
T
,
int
NX
>
__device__
__inline__
void
Init
(
T
*
dst
,
T
init_data
,
int
read_lens
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
read_lens
;
i
++
)
{
dst
[
i
]
=
init_data
;
}
}
/**
* The difference from the above function is that
* it supports different data types of inputs.
...
...
@@ -251,6 +470,26 @@ __device__ __inline__ void ReadData(T* dst,
}
}
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
bool
IsBoundary
>
__device__
__inline__
void
ReadData
(
T
*
dst
,
const
T
_global_ptr_
*
src
,
int
num
,
int
read_lens
)
{
int
thread_offset
=
core_id
()
*
read_lens
;
__local__
T
in_temp
[
1
];
if
(
IsBoundary
)
{
// core_num() * read_lens > num
#pragma unroll
for
(
int
idx
=
0
;
idx
<
read_lens
;
++
idx
)
{
if
(
idx
+
thread_offset
<
num
)
{
GM2LM
(
src
+
thread_offset
+
idx
,
in_temp
,
sizeof
(
T
));
dst
[
idx
]
=
in_temp
[
0
];
}
}
}
else
{
// core_num() * read_lens < num
GM2LM
(
src
+
thread_offset
,
dst
,
read_lens
*
sizeof
(
T
));
}
}
/**
* @brief Read 1D data from global memory to register. The difference
* from the above function is that it supports different data types of inputs.
...
...
@@ -479,10 +718,32 @@ __device__ __forceinline__ void ReadDataReduce(
* size: The current block needs to load size elements continuously.
*/
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
bool
IsBoundary
>
__device__
void
WriteData
(
T
_global_ptr_
*
dst
,
const
T
*
src
,
int
num
,
int
read_lens
)
{
int
thread_offset
=
core_id
()
*
read_lens
;
__local__
T
in_temp
[
1
];
if
(
IsBoundary
)
{
// core_num() * read_lens > num
#pragma unroll
for
(
int
idx
=
0
;
idx
<
read_lens
;
++
idx
)
{
if
(
idx
+
thread_offset
<
num
)
{
in_temp
[
0
]
=
src
[
idx
];
LM2GM
(
in_temp
,
dst
+
idx
+
thread_offset
,
sizeof
(
T
));
}
}
}
else
{
// core_num() * read_lens < num
LM2GM
(
src
,
dst
+
thread_offset
,
read_lens
*
sizeof
(
T
));
}
}
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
bool
IsBoundary
>
__device__
void
WriteData
(
T
_global_ptr_
*
dst
,
const
T
*
src
,
int
num
)
{
int
thread_offset
=
core_id
()
*
NX
;
__local__
T
in_temp
[
1
];
if
(
IsBoundary
)
{
// core_num() * NX > num
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
;
++
idx
)
{
...
...
@@ -675,6 +936,331 @@ __device__ __inline__ void ReadDataBc(
}
}
/**
* @brief Read data from global memory to local memory with broadcast
* {m, 1, k}-> {m, n, k} form.
*
* @template paraments
* T: Data type of register.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
*
* @param:
* dst: The register pointer of the thread, the size is NX.
* src: The original input data pointer of kernel.
* thread_offset: The data offset of this thread.
* config: Calculation configuration of broadcast. It is used to calculate the
* coordinate mapping relationship between output data and input data.
* read_lens: The number of data continuously loaded by each thread.
*/
template
<
typename
T
,
int
Rank
>
__device__
__inline__
void
ReadDataBcM1kMnk
(
T
*
dst
,
const
T
_global_ptr_
*
src
,
int
thread_offset
,
const
details
::
BroadcastConfig
<
Rank
>&
config
,
int
read_lens
)
{
int
index_output
=
thread_offset
;
int
index_base
=
config
(
index_output
);
int
m
=
config
.
m
;
int
n
=
config
.
n
;
int
m_pos
=
index_base
%
m
;
if
((
m
-
m_pos
)
<
read_lens
)
{
int
last_col
=
m
-
m_pos
;
GM2LM
(
src
+
index_base
,
dst
,
last_col
*
sizeof
(
T
));
int
n_pos
=
index_output
%
(
m
*
n
)
/
m
;
int
next_part_index
=
0
;
if
(
n_pos
!=
config
.
n
-
1
)
{
next_part_index
=
index_base
/
m
*
m
;
}
else
{
next_part_index
=
(
index_base
/
m
+
1
)
*
m
;
}
GM2LM
(
src
+
next_part_index
,
dst
+
last_col
,
(
read_lens
-
last_col
)
*
sizeof
(
T
));
}
else
{
GM2LM
(
src
+
index_base
,
dst
,
read_lens
*
sizeof
(
T
));
}
}
/**
* @brief Read data from global memory to local memory with broadcast
* {m, 1}-> {m, n} form.
*
* @template paraments
* T: Data type of register.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
*
* @param:
* dst: The register pointer of the thread, the size is NX.
* src: The original input data pointer of kernel.
* thread_offset: The data offset of this thread.
* config: Calculation configuration of broadcast. It is used to calculate the
* coordinate mapping relationship between output data and input data.
* read_lens: The number of data continuously loaded by each thread.
*/
template
<
typename
T
,
int
Rank
>
__device__
__inline__
void
ReadDataBcM1Mn
(
T
*
dst
,
const
T
_global_ptr_
*
src
,
int
thread_offset
,
const
details
::
BroadcastConfig
<
Rank
>&
config
,
int
read_lens
)
{
int
index_output
=
thread_offset
;
int
index_base
=
config
(
index_output
);
int
m
=
config
.
m
;
int
n
=
config
.
n
;
int
m_pos
=
index_base
%
m
;
if
((
m
-
m_pos
)
<
read_lens
)
{
int
last_col
=
m
-
m_pos
;
GM2LM
(
src
+
index_base
,
dst
,
last_col
*
sizeof
(
T
));
GM2LM
(
src
,
dst
+
last_col
,
(
read_lens
-
last_col
)
*
sizeof
(
T
));
}
else
{
GM2LM
(
src
+
index_base
,
dst
,
read_lens
*
sizeof
(
T
));
}
}
/**
* @brief Read data from global memory to local memory with broadcast
* {1, n}-> {m, n} form.
*
* @template paraments
* T: Data type of register.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
*
* @param:
* dst: The register pointer of the thread, the size is NX.
* src: The original input data pointer of kernel.
* thread_offset: The data offset of this thread.
* config: Calculation configuration of broadcast. It is used to calculate the
* coordinate mapping relationship between output data and input data.
* read_lens: The number of data continuously loaded by each thread.
*/
template
<
typename
T
,
int
Rank
>
__device__
__inline__
void
ReadDataBc1NMn
(
T
*
dst
,
const
T
_global_ptr_
*
src
,
int
thread_offset
,
const
details
::
BroadcastConfig
<
Rank
>&
config
,
int
read_lens
)
{
int
index_output
=
thread_offset
;
int
index_base
=
config
(
index_output
);
int
m
=
config
.
m
;
int
n
=
config
.
n
;
T
in_temp
;
int
m_pos
=
index_output
%
m
;
if
((
m
-
m_pos
)
<
read_lens
)
{
int
last_col
=
m
-
m_pos
;
GM2LM
(
src
+
index_base
,
&
in_temp
,
sizeof
(
T
));
for
(
int
i
=
0
;
i
<
last_col
;
i
++
)
{
dst
[
i
]
=
in_temp
;
}
GM2LM
(
src
+
index_base
+
1
,
&
in_temp
,
sizeof
(
T
));
for
(
int
i
=
0
;
i
<
read_lens
-
last_col
;
i
++
)
{
dst
[
last_col
+
i
]
=
in_temp
;
}
}
else
{
GM2LM
(
src
+
index_base
,
&
in_temp
,
sizeof
(
T
));
for
(
int
i
=
0
;
i
<
read_lens
;
i
++
)
{
dst
[
i
]
=
in_temp
;
}
}
}
/**
* @brief Read data from global memory to local memory with broadcast
* {1, n, 1}-> {m, n, k} form.
*
* @template paraments
* T: Data type of register.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
*
* @param:
* dst: The register pointer of the thread, the size is NX.
* src: The original input data pointer of kernel.
* thread_offset: The data offset of this thread.
* config: Calculation configuration of broadcast. It is used to calculate the
* coordinate mapping relationship between output data and input data.
* read_lens: The number of data continuously loaded by each thread.
*/
template
<
typename
T
,
int
Rank
>
__device__
__inline__
void
ReadDataBc1N1Mnk
(
T
*
dst
,
const
T
_global_ptr_
*
src
,
int
thread_offset
,
const
details
::
BroadcastConfig
<
Rank
>&
config
,
int
read_lens
)
{
int
index_output
=
thread_offset
;
int
index_base
=
config
(
index_output
);
int
m
=
config
.
m
;
int
n
=
config
.
n
;
T
in_temp
;
int
m_pos
=
index_output
%
m
;
if
((
m
-
m_pos
)
<
read_lens
)
{
int
last_col
=
m
-
m_pos
;
GM2LM
(
src
+
index_base
,
&
in_temp
,
sizeof
(
T
));
for
(
int
i
=
0
;
i
<
last_col
;
i
++
)
{
dst
[
i
]
=
in_temp
;
}
int
n_pos
=
index_output
%
(
m
*
n
)
/
m
;
int
next_part_index
=
0
;
if
(
n_pos
!=
n
-
1
)
{
next_part_index
=
n_pos
+
1
;
}
else
{
next_part_index
=
0
;
}
GM2LM
(
src
+
next_part_index
,
&
in_temp
,
sizeof
(
T
));
for
(
int
i
=
0
;
i
<
read_lens
-
last_col
;
i
++
)
{
dst
[
last_col
+
i
]
=
in_temp
;
}
}
else
{
GM2LM
(
src
+
index_base
,
&
in_temp
,
sizeof
(
T
));
for
(
int
i
=
0
;
i
<
read_lens
;
i
++
)
{
dst
[
i
]
=
in_temp
;
}
}
}
/**
* @brief Read data from global memory to local memory with broadcast
* {1}-> {n} form.
*
* @template paraments
* T: Data type of register.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
*
* @param:
* dst: The register pointer of the thread, the size is NX.
* src: The original input data pointer of kernel.
* thread_offset: The data offset of this thread.
* config: Calculation configuration of broadcast. It is used to calculate the
* coordinate mapping relationship between output data and input data.
* read_lens: The number of data continuously loaded by each thread.
*/
template
<
typename
T
,
int
Rank
>
__device__
__inline__
void
ReadDataBc1N
(
T
*
dst
,
const
T
_global_ptr_
*
src
,
int
thread_offset
,
const
details
::
BroadcastConfig
<
Rank
>&
config
,
int
read_lens
)
{
int
index_output
=
thread_offset
;
int
index_base
=
config
(
index_output
);
T
in_temp
;
GM2LM
(
src
+
index_base
,
&
in_temp
,
sizeof
(
T
));
for
(
int
i
=
0
;
i
<
read_lens
;
i
++
)
{
dst
[
i
]
=
in_temp
;
}
}
/**
* @brief Read data from global memory to local memory with broadcast
* form which can not compress.
*
* @template paraments
* T: Data type of register.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
*
* @param:
* dst: The register pointer of the thread, the size is NX.
* src: The original input data pointer of kernel.
* thread_offset: The data offset of this thread.
* config: Calculation configuration of broadcast. It is used to calculate the
* coordinate mapping relationship between output data and input data.
* total_num_output: Total number of original output.
* read_lens: The number of data continuously loaded by each thread.
*/
template
<
typename
T
,
int
Rank
,
bool
IsBoundary
=
false
>
__device__
__inline__
void
ReadDataBcCanNotCmp
(
T
*
dst
,
const
T
_global_ptr_
*
src
,
int
thread_offset
,
const
details
::
BroadcastConfig
<
Rank
>&
config
,
int
total_num_output
,
int
read_lens
)
{
int
index_output
=
thread_offset
;
int
index_base
=
config
(
index_output
);
T
in_temp
;
int
cache_size
=
256
;
__local__
T
src_temp
[
cache_size
];
GM2LM
(
src
+
index_base
,
src_temp
,
cache_size
*
sizeof
(
T
));
for
(
int
nx
=
0
;
nx
<
read_lens
;
++
nx
)
{
index_output
=
thread_offset
+
nx
;
if
(
IsBoundary
)
{
if
(
index_output
>=
total_num_output
)
{
break
;
}
}
int
index_src
=
config
(
index_output
);
if
(
index_src
>=
index_base
&&
index_src
<
index_base
+
cache_size
)
{
in_temp
=
src_temp
[
index_src
-
index_base
];
}
else
{
GM2LM
(
src
+
index_src
,
&
in_temp
,
sizeof
(
T
));
}
dst
[
nx
]
=
in_temp
;
}
}
/**
* @brief Read 1D data from global memory to register with broadcast form.
*
* @template paraments
* T: The type of data stored in the global memory.
* NX: The number of data continuously loaded by each thread.
* NY: The number of data rows loaded by each thread, only NY = 1 was supported.
* BlockSize: Identifies the current device thread index method. For xpu,
* core_id() is used as the index.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
* IsBoundary: Indicates whether to perform block access storage out-of-bounds
* judgment. When the number of data processed by the block is less than
* NX x NY x core_num(), boundary judgment is required to avoid memory access
* crossing the boundary.
*
* @param:
* dst: The register pointer of the thread, the size is NX * NY.
* src: The original input data pointer of kernel.
* block_offset: The data offset of this block, core_num() * blockIdx.x * NX;
* config: Calculation configuration of broadcast. It is used to calculate the
* coordinate mapping relationship between output data and input data.
* read_lens: The number of data continuously loaded by each thread.
* total_num_output: Total number of original output.
*/
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
int
Rank
,
bool
IsBoundary
=
false
>
__device__
__inline__
void
ReadDataBc
(
T
*
dst
,
const
T
_global_ptr_
*
src
,
uint32_t
block_offset
,
const
details
::
BroadcastConfig
<
Rank
>&
config
,
int
total_num_output
,
int
read_lens
)
{
int
thread_offset
=
block_offset
+
core_id
()
*
read_lens
;
if
(
config
.
cmp_type
==
details
::
OptType
::
MNK_M1K
)
{
ReadDataBcM1kMnk
<
T
,
Rank
>
(
dst
,
src
,
thread_offset
,
config
,
read_lens
);
}
else
if
(
config
.
cmp_type
==
details
::
OptType
::
N_1
)
{
ReadDataBc1N
<
T
,
Rank
>
(
dst
,
src
,
thread_offset
,
config
,
read_lens
);
}
else
if
(
config
.
cmp_type
==
details
::
OptType
::
MN_M
)
{
ReadDataBcM1Mn
<
T
,
Rank
>
(
dst
,
src
,
thread_offset
,
config
,
read_lens
);
}
else
if
(
config
.
cmp_type
==
details
::
OptType
::
MN_N
)
{
ReadDataBc1NMn
<
T
,
Rank
>
(
dst
,
src
,
thread_offset
,
config
,
read_lens
);
}
else
if
(
config
.
cmp_type
==
details
::
OptType
::
MNK_1N1
)
{
ReadDataBc1N1Mnk
<
T
,
Rank
>
(
dst
,
src
,
thread_offset
,
config
,
read_lens
);
}
else
{
ReadDataBcCanNotCmp
<
T
,
Rank
,
IsBoundary
>
(
dst
,
src
,
thread_offset
,
config
,
total_num_output
,
read_lens
);
}
}
/**
* @brief Initialize register with data index.
*
...
...
paddle/phi/kernels/primitive/kernel_primitives.h
浏览文件 @
c7855125
...
...
@@ -46,6 +46,7 @@
#define KPStream gpuStream_t
#define KPDevice phi::GPUContext
#define _ptr_
#define __simd__
#define THREAD_ID_X threadIdx.x
#define THREAD_ID_Y threadIdx.y
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录