Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
bc0df48b
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
bc0df48b
编写于
9月 26, 2021
作者:
N
niuliling123
提交者:
GitHub
9月 26, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[icafe-31094] Add function comments and instructions to the Primitive API (#35743)
上级
372a1a75
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
299 addition
and
201 deletion
+299
-201
paddle/fluid/operators/kernel_primitives/compute_primitives.h
...le/fluid/operators/kernel_primitives/compute_primitives.h
+155
-87
paddle/fluid/operators/kernel_primitives/datamover_primitives.h
.../fluid/operators/kernel_primitives/datamover_primitives.h
+144
-114
未找到文件。
paddle/fluid/operators/kernel_primitives/compute_primitives.h
浏览文件 @
bc0df48b
...
...
@@ -54,8 +54,8 @@ class MPTypeTrait<platform::float16> {
};
/**
* @brief
w
ill be used in BlockYReduce, get the index of reduce_num in shared
* memory
* @brief
W
ill be used in BlockYReduce, get the index of reduce_num in shared
* memory
.
*/
__device__
__forceinline__
int
SharedMemoryIndex
(
int
index
)
{
return
(
threadIdx
.
y
+
index
)
*
blockDim
.
x
+
threadIdx
.
x
;
...
...
@@ -83,7 +83,7 @@ __device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
*/
/**
* @brief BlockXReduce reduce along blockDim.x
* @brief BlockXReduce reduce along blockDim.x
.
*/
template
<
typename
T
,
typename
ReduceOp
>
__device__
__forceinline__
T
BlockXReduce
(
T
val
,
ReduceOp
reducer
)
{
...
...
@@ -115,7 +115,7 @@ __device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
}
/**
* @brief BlockYReduce reduce along blockDim.y
* @brief BlockYReduce reduce along blockDim.y
.
*/
template
<
typename
T
,
typename
ReduceOp
>
__device__
__forceinline__
T
BlockYReduce
(
T
val
,
ReduceOp
reducer
)
{
...
...
@@ -135,24 +135,33 @@ __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
}
// namespace details
/**
* @brief unary function
* @param
* T: data type of in
* OutT: data type of out
* NX: the cols of in
* NY: the rows of in
* BlockSize: the config of this device
* OpFunc: compute functor which have an operator() as following
* template <typename T, typename OutT>
* @brief Perform unary calculation according to OpFunc. Size of input and
* output are the same.
*
* @template paraments
* InT: Data type of in.
* OutT: Data type of out.
* NX: The number of data columns loaded by each thread.
* NY: The number of data rows loaded by each thread.
* BlockSize: Identifies the current device thread index method. For GPU,
* threadIdx.x is used as the thread index, and for xpu, core_id() is used as
* the index. Currently only GPU was supported.
* OpFunc: Compute functor which has an operator() as following:
* template <typename InT, typename OutT>
* struct XxxFunctor {
* HOSTDEVICE OutT operator()(const T& a) const {
* HOSTDEVICE OutT operator()(const
In
T& a) const {
* return ...;
* }
* };
*
* @param:
* out: The register pointer of out, the size is NX * NY.
* in: The register pointer of in, the size is NX * NY.
* compute: Compute function which was declared like OpFunc<InT, OutT>().
*/
template
<
typename
T
,
typename
OutT
,
int
NX
,
int
NY
,
int
BlockSize
,
template
<
typename
In
T
,
typename
OutT
,
int
NX
,
int
NY
,
int
BlockSize
,
class
OpFunc
>
__device__
__forceinline__
void
ElementwiseUnary
(
OutT
*
out
,
const
T
*
in
,
__device__
__forceinline__
void
ElementwiseUnary
(
OutT
*
out
,
const
In
T
*
in
,
OpFunc
compute
)
{
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
*
NY
;
idx
++
)
{
...
...
@@ -161,25 +170,35 @@ __device__ __forceinline__ void ElementwiseUnary(OutT* out, const T* in,
}
/**
* @brief binary function, in1 and in2 have same shape
* @param
* T: data type of in1, in2
* OutT: data type of out
* NX: the cols of in1, in2
* NY: the rows of in1, in2
* BlockSize: the config of this device
* OpFunc: compute functor which have an operator() as following
* template <typename T, typename OutT>
* @brief Binary calculation according to OpFunc. Size of The input and output
* are the same.
*
* @template paraments
* InT: Data type of in1 and in2.
* OutT: Data type of out.
* NX: The number of data columns loaded by each thread.
* NY: The number of data rows loaded by each thread.
* BlockSize: Identifies the current device thread index method. For GPU,
* threadIdx.x is used as the thread index, and for xpu, core_id() is used as
* the index. Currently only GPU was supported.
* OpFunc: Compute functor which has an operator() as following:
* template <typename InT, typename OutT>
* struct XxxFunctor {
* HOSTDEVICE OutT operator()(const
T& a, const
T& b) const {
* HOSTDEVICE OutT operator()(const
InT& a, const In
T& b) const {
* return ...;
* }
* };
*
* @param:
* out: The register pointer of out, the size is NX * NY.
* in1: The register pointer of fist input, size is NX * NY.
* in2: The register pointer of second input, size is NX * NY.
* compute: Compute function which was declared like OpFunc<InT, OutT>().
*/
template
<
typename
T
,
typename
OutT
,
int
NX
,
int
NY
,
int
BlockSize
,
template
<
typename
In
T
,
typename
OutT
,
int
NX
,
int
NY
,
int
BlockSize
,
class
OpFunc
>
__device__
__forceinline__
void
ElementwiseBinary
(
OutT
*
out
,
const
T
*
in1
,
const
T
*
in2
,
__device__
__forceinline__
void
ElementwiseBinary
(
OutT
*
out
,
const
In
T
*
in1
,
const
In
T
*
in2
,
OpFunc
compute
)
{
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
*
NY
;
++
idx
)
{
...
...
@@ -188,25 +207,38 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out, const T* in1,
}
/**
* @brief ternary function, in1, in2 and in3 have same shape
* @param
* T: data type of in1, in2, in3
* OutT: data type of out
* NX: the cols of in1, in2
* NY: the rows of in1, in2
* BlockSize: the config of this device
* OpFunc: compute functor which have an operator() as following
* template <typename T, typename OutT>
* @brief Ternary calculation according to OpFunc. Size of input and output
* are the same.
*
* @template paraments
* InT: Data type of in1 and in2.
* OutT: Data type of out.
* NX: The number of data columns loaded by each thread.
* NY: The number of data rows loaded by each thread.
* BlockSize: Identifies the current device thread index method. For GPU,
* threadIdx.x is used as the thread index, and for xpu, core_id() is used as
* the index. Currently only GPU was supported.
* OpFunc: Compute functor which has an operator() as following
* template <typename InT, typename OutT>
* struct XxxFunctor {
* HOSTDEVICE OutT operator()(const T& a, const T& b, const T& c) const {
* HOSTDEVICE OutT operator()(const InT& a, const InT& b, const InT& c)
* const {
* return ...;
* }
* };
*
* @param
* out: The register pointer of out, the size is NX * NY.
* in1: The register pointer of fist input, size is NX * NY.
* in2: The register pointer of second input, size is NX * NY.
* in3: The register pointer of third input, size is NX * NY.
* compute: Compute function which was declared like OpFunc<InT, OutT>().
*/
template
<
typename
T
,
typename
OutT
,
int
NX
,
int
NY
,
int
BlockSize
,
template
<
typename
In
T
,
typename
OutT
,
int
NX
,
int
NY
,
int
BlockSize
,
class
OpFunc
>
__device__
__forceinline__
void
ElementwiseTernary
(
OutT
*
out
,
const
T
*
in1
,
const
T
*
in2
,
const
T
*
in3
,
__device__
__forceinline__
void
ElementwiseTernary
(
OutT
*
out
,
const
InT
*
in1
,
const
InT
*
in2
,
const
InT
*
in3
,
OpFunc
compute
)
{
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
*
NY
;
++
idx
)
{
...
...
@@ -215,27 +247,36 @@ __device__ __forceinline__ void ElementwiseTernary(OutT* out, const T* in1,
}
/**
* @brief a general function for elementwise computation, all inputs have
* the same shape.
* @param
* T: data type of in1, in2, in3
* OutT: data type of out
* NX: the cols of in1, in2
* NY: the rows of in1, in2
* BlockSize: the config of this device
* OpFunc: compute functor which have an operator() as following
* template <typename T, typename OutT>
* @brief Multivariate calculation according to OpFunc. Size of input and output
* are the same.
*
* @template paraments
* InT: Data type of in1, in2 and in3.
* OutT: Data type of out.
* NX: The number of data columns loaded by each thread.
* NY: The number of data rows loaded by each thread.
* BlockSize: Identifies the current device thread index method. For GPU,
* threadIdx.x is used as the thread index, and for xpu, core_id() is used as
* the index. Currently only GPU was supported.
* Arity: The size of ins
* OpFunc: Compute functor which has an operator() as following:
* template <typename InT, typename OutT>
* struct XxxFunctor {
* HOSTDEVICE OutT operator()(const T* args) const {
* HOSTDEVICE OutT operator()(const
In
T* args) const {
* return ...;
* }
* };
*
* @param
* out: The register pointer of out, the size is NX * NY.
* ins: An array of pointers consisting of multiple inputs.
* compute: Compute function which was declared like OpFunc<InT, OutT>().
*/
template
<
typename
T
,
typename
OutT
,
int
NX
,
int
NY
,
int
BlockSize
,
int
Arity
,
template
<
typename
In
T
,
typename
OutT
,
int
NX
,
int
NY
,
int
BlockSize
,
int
Arity
,
class
OpFunc
>
__device__
__forceinline__
void
ElementwiseAny
(
OutT
*
out
,
T
(
*
ins
)[
NX
*
NY
],
__device__
__forceinline__
void
ElementwiseAny
(
OutT
*
out
,
In
T
(
*
ins
)[
NX
*
NY
],
OpFunc
compute
)
{
T
args
[
Arity
];
In
T
args
[
Arity
];
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
*
NY
;
++
idx
)
{
#pragma unroll
...
...
@@ -247,20 +288,36 @@ __device__ __forceinline__ void ElementwiseAny(OutT* out, T (*ins)[NX * NY],
}
/**
* @brief cycle binary function, in1's shape size is [1, NX], in2's shape size
* is [NY, NX], out's shape size is [NY, NX]
* @brief Binary calculation according to OpFunc. Shape of in1 and in2 are the
* different. Shape of in1 is [1, NX], but in2's shape is [NY, NX], the output
* shape is [NY, NX].
*
* @template paraments
* InT: Data type of in1 and in2.
* OutT: Data type of out.
* NX: The number of data columns loaded by each thread.
* NY: The number of data rows loaded by each thread.
* BlockSize: Identifies the current device thread index method. For GPU,
* threadIdx.x is used as the thread index, and for xpu, core_id() is used as
* the index. Currently only GPU was supported.
* OpFunc: Compute functor which has an operator() as following
* template <typename InT, typename OutT>
* struct XxxFunctor {
* HOSTDEVICE OutT operator()(const InT& a, const InT& b) const {
* return ...;
* }
* };
*
* @param
* T: data type of in1, in2
* OutT: data type of out
* NX: the cols of in1, in2
* NY: the rows of in1, in2
* BlockSize: the config of this device
* OpFunc: compute functor eg: in1 + in2, in1 - in2
* out: The register pointer of out, the size is NX * NY.
* in1: The register pointer of fist input, size is NX * 1.
* in2: The register pointer of second input, size is NX * NY.
* compute: Compute function which was declared like OpFunc<InT, OutT>().
*/
template
<
typename
T
,
typename
OutT
,
int
NX
,
int
NY
,
int
BlockSize
,
template
<
typename
In
T
,
typename
OutT
,
int
NX
,
int
NY
,
int
BlockSize
,
class
OpFunc
>
__device__
__forceinline__
void
CycleBinary
(
OutT
*
out
,
const
T
*
in1
,
const
T
*
in2
,
OpFunc
compute
)
{
__device__
__forceinline__
void
CycleBinary
(
OutT
*
out
,
const
In
T
*
in1
,
const
In
T
*
in2
,
OpFunc
compute
)
{
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
;
idx
++
)
{
#pragma unroll
...
...
@@ -272,26 +329,37 @@ __device__ __forceinline__ void CycleBinary(OutT* out, const T* in1,
}
/**
* @brief reduce function, in's shape size is [NX, NY].
* If ReduceMode == kLocalMode then reduce NX, the shape of out is [NY, 1],
* if ReduceMode == kGlobalMode then reduce between different threads, the
* shape of out is [NY, NX]. If reduce_last_dim is false and reduce_num was
* split, BlockYReduce will be called. If reduce_last_dim is true and
* reduce_num was split, BlockXReduce will be called
* @typename
* T: data type of in
* NX: the cols of in
* NY: the rows of in
* BlockSize: the config of this device
* OpFunc: reduce functor, eg: CustomSum, CustomMean in reduce_functor_op.h
* @param:
* reducer: reduce functor, eg: CustomSum<T>()
* reduce_last_dim: if in's last dim need to be reduce then reduce_last_dim =
* true
* @brief The Reduce provides collective methods for computing a parallel
* reduction of items partitioned across a CUDA block and intra thread. When
* ReduceMode == kLocalMode, thread reduce along nx. When ReduceMode ==
* kGlobalMode, use shared memory to reduce between threads.
*
* @template paraments
* T: The type of data.
* NX: The number of data continuously loaded by each thread.
* NY: The number of data rows loaded by each thread, only NY = 1 was supported.
* BlockSize: Identifies the current device thread index method. For GPU,
* threadIdx.x is used as the thread index, and for xpu, core_id() is used as
* the index. Currently only GPU was supported.
* ReduceFunctor: Compute functor which has an operator() as following
* template <typename InT>
* struct ReduceFunctor {
* HOSTDEVICE InT operator()(const InT& a, const InT& b) const {
* return ...;
* }
* };
* ReduceMode: Reduce mode, can be kLocalMode, kGlobalMode.
*
* @param
* out: The register pointer of out, the size is NX * NY.
* in: The register pointer of in, the size is NX * NY.
* reducer: Compute function which was declared like ReduceFunctor<InT>().
* reduce_last_dim: if the last dim gets involved in reduction.
*/
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
class
OpFunc
,
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
class
ReduceFunctor
,
details
::
ReduceMode
Mode
>
__device__
__forceinline__
void
Reduce
(
T
*
out
,
const
T
*
in
,
OpFunc
reducer
,
__device__
__forceinline__
void
Reduce
(
T
*
out
,
const
T
*
in
,
ReduceFunctor
reducer
,
bool
reduce_last_dim
)
{
int
block_index
=
blockDim
.
y
;
...
...
@@ -302,7 +370,7 @@ __device__ __forceinline__ void Reduce(T* out, const T* in, OpFunc reducer,
if
(
block_reduce_y
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
NY
*
NX
;
i
++
)
{
// reduce along blockdim.y
out
[
i
]
=
details
::
BlockYReduce
<
T
,
OpFunc
>
(
out
[
i
],
reducer
);
out
[
i
]
=
details
::
BlockYReduce
<
T
,
ReduceFunctor
>
(
out
[
i
],
reducer
);
}
}
...
...
@@ -310,7 +378,7 @@ __device__ __forceinline__ void Reduce(T* out, const T* in, OpFunc reducer,
if
(
reduce_last_dim
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
NY
*
NX
;
i
++
)
{
// reduce along blockDim.x
out
[
i
]
=
details
::
BlockXReduce
<
T
,
OpFunc
>
(
out
[
i
],
reducer
);
out
[
i
]
=
details
::
BlockXReduce
<
T
,
ReduceFunctor
>
(
out
[
i
],
reducer
);
}
}
}
else
{
// else kLocalMode
...
...
paddle/fluid/operators/kernel_primitives/datamover_primitives.h
浏览文件 @
bc0df48b
...
...
@@ -32,7 +32,13 @@ template <typename T, int VecSize>
struct
alignas
(
sizeof
(
T
)
*
VecSize
)
VectorType
{
T
val
[
VecSize
];
};
/**
* Fast division : Replace division in CUDA with multiplication to improve
* kernel performance.
* 1. Complete the division calculation on the CPU, and record the calculation
* results by using the divider and shift_val.
* 2. Set the divisor on the GPU through Div() to complete the calculation.
*/
struct
FastDivMod
{
// 1st value represents the result of input number divides by recorded divisor
// 2nd value represents the result of input number modulo by recorded divisor
...
...
@@ -71,6 +77,11 @@ struct FastDivMod {
uint32_t
multiplier
;
};
/**
* Configuration of broadcast. Calculate the input data index according to the
* index of the output data. if input or output shape is [dim0, dim1] then dims
* must be [dim1, dim0].
*/
template
<
int
kDims
>
struct
BroadcastConfig
{
FastDivMod
divmoders
[
kDims
];
...
...
@@ -107,65 +118,31 @@ struct BroadcastConfig {
}
// namespace details
/**
* @brief load data from src to dst, src can be 1D data or 2D data. Note that
* you can use this function when you are sure that the data will not cross the
* boundary.
* @typename:
* Tx: data type of src
* Ty: data type of dstt
* NX: the cols of src, dst
* NY: the rows of src, dst
* BlockSize: the config of this device
* @param:
* stride_nx: the stride of cols
* stride_ny: the stride of rows
*/
template
<
typename
Tx
,
typename
Ty
,
int
NX
,
int
NY
,
int
BlockSize
>
__device__
__forceinline__
void
ReadData
(
Ty
*
dst
,
const
Tx
*
__restrict__
src
,
int
stride_nx
,
int
stride_ny
)
{
int
thread_offset
=
threadIdx
.
x
*
NX
;
if
(
NY
==
1
&&
NX
==
1
)
{
dst
[
0
]
=
static_cast
<
Ty
>
(
src
[
thread_offset
]);
}
else
if
(
NX
==
1
)
{
#pragma unroll
for
(
int
idy
=
0
;
idy
<
NY
;
++
idy
)
{
dst
[
idy
]
=
static_cast
<
Ty
>
(
src
[
thread_offset
+
idy
*
stride_ny
]);
}
}
else
if
(
NY
==
1
)
{
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
;
++
idx
)
{
dst
[
idx
]
=
static_cast
<
Ty
>
(
src
[
thread_offset
+
idx
*
stride_nx
]);
}
}
else
{
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
;
++
idx
)
{
#pragma unroll
for
(
int
idy
=
0
;
idy
<
NY
;
++
idy
)
{
dst
[
idy
*
NX
+
idx
]
=
static_cast
<
Ty
>
(
src
[
thread_offset
+
idx
*
stride_nx
+
idy
*
stride_ny
]);
}
}
}
}
/**
* @brief load data from src to dst with stride, src can be 1D data or 2D data.
* When boundary judgment is required, you need to set a to true, and a is false
* by default.
* @typename:
* Tx: data type of src
* Ty: data type of dstt
* NX: the cols of src, dst
* NY: the rows of src, dst
* BlockSize: the config of this device
* IsBoundary: whether to make boundary judgment
* @brief Read 2D data from global memory to registers according to Tx type, and
* store it as Ty type.
*
* @template paraments
* Tx: The type of data stored in the global memory.
* Ty: The type of data that needs to be stored in registers.
* NX: The number of data columns loaded by each thread.
* NY: The number of data rows loaded by each thread.
* BlockSize: Identifies the current device thread index method. For GPU,
* threadIdx.x is used as the thread index, and for xpu, core_id() is used as
* the index. Currently only GPU was supported.
* IsBoundary: Indicates whether to perform block access storage out-of-bounds
* judgment. When the number of data processed by the block is less than
* NX x NY x blockDim, boundary judgment is required to avoid memory access
* crossing the boundary.
*
* @param:
* size_nx: number of columns to be processed by the current block
* size_ny: number of rows to be processed by the current block
* stride_nx: the stride of cols
* stride_ny: the stride of rows
* dst: The register pointer of the thread, the size is NX * NY.
* src: Data pointer of the current block.
* size_nx: The current block needs to load size_nx columns of data, this
* parameter will be used when IsBoundary = true.
* size_ny: The current block needs to load size_ny rows of data. This parameter
* will be used when IsBoundary = true.
* stride_nx: The stride of cols.
* stride_ny: The stride of rows.
*/
template
<
typename
Tx
,
typename
Ty
,
int
NX
,
int
NY
,
int
BlockSize
,
bool
IsBoundary
=
false
>
...
...
@@ -226,6 +203,17 @@ __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
}
}
/**
* @brief Initialize register with init_data.
*
* @template paraments
* T: Data type of register.
* NX: Number of data to initialize.
*
* @param:
* dst: The register pointer of the thread, the size is NX.
* init_data: Initial value.
*/
template
<
typename
T
,
int
NX
>
__device__
__forceinline__
void
Init
(
T
*
dst
,
T
init_data
)
{
#pragma unroll
...
...
@@ -234,18 +222,27 @@ __device__ __forceinline__ void Init(T* dst, T init_data) {
}
}
/** @brief: ReadData
* @brief load data from src to dst, src can be 1D data, you should set NY = 1.
* When boundary judgment is required, you need to set a to true, and a is false
* by default.
* @typename:
* T : the data type of src
* NX: the cols of src, dst
* NY: in this function NY only can be 1
* BlockSize: the config of this device
* IsBoundary: whether to make boundary judgment
/**
* @brief Read 2D data from global memory to registers. When IsBoundary = true
* and (NX % 4 == 0 or Nx % 2 == 0), vectorized load data will be used to
* improve memory access efficiency.
*
* @template paraments
* T: Data type of src and dst.
* NX: The number of data continuously loaded by each thread.
* NY: The number of data rows loaded by each thread, only NY = 1 was supported.
* BlockSize: Identifies the current device thread index method. For GPU,
* threadIdx.x is used as the thread index, and for xpu, core_id() is used as
* the index. Currently only GPU was supported.
* IsBoundary: Whether to make an out-of-bounds judgment on access to memory.
* When the number of data processed by this block is less than
* NX x NY x blockDim, boundary judgment is required to avoid memory access
* crossing the boundary.
*
* @param:
* num: number of columns to be processed by the current block
* dst: The register pointer of the thread, the size is NX * NY.
* src: Data pointer of the current block.
* size: The current block needs to load size data continuously.
*/
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
bool
IsBoundary
=
false
>
__device__
__forceinline__
void
ReadData
(
T
*
dst
,
const
T
*
__restrict__
src
,
...
...
@@ -279,28 +276,38 @@ __device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
}
/**
* @brief: read data for broadcast
* @typename:
* T : the data type of src
* NX: the cols of src, dst
* NY: in this function NY only can be 1
* BlockSize: the config of this device
* ShapeSize: the shape size of out. eg in[1, 35], out[32, 35] then shape size
* is 2
* IsBoundary: whether to make boundary judgment
* @brief Read 2D data from global memory to registers for broadcast.
*
* @template paraments
* T: The type of data stored in the global memory.
* NX: The number of data columns loaded by each thread.
* NY: The number of data rows loaded by each thread.
* BlockSize: Identifies the current device thread index method. For GPU,
* threadIdx.x is used as the thread index, and for xpu, core_id() is used as
* the index. Currently only GPU was supported.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
* IsBoundary: Indicates whether to perform block access storage out-of-bounds
* judgment. When the number of data processed by the block is less than
* NX x NY x blockDim, boundary judgment is required to avoid memory access
* crossing the boundary.
*
* @param:
* block_offset: data offset of this block, blockDim.x * blockIdx.x * NX;
* config: get the global index in src, attention config was declared in host;
* total_num_output: total num of output
* stride_nx: the stride of cols
* stride_ny: the stride of rows
* dst: The register pointer of the thread, the size is NX * NY.
* src: Raw input data pointer of kernel.
* block_offset: Data offset of this block, blockDim.x * blockIdx.x * NX;
* config: Calculation configuration of broadcast. It is used to calculate the
* coordinate mapping relationship between output data and input data. Please
* refer to the sample code for specific usage.
* total_num_output: Total number of original output.
* stride_nx: The stride of cols.
* stride_ny: The stride of rows.
*/
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
int
ShapeSize
,
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
int
Rank
,
bool
IsBoundary
=
false
>
__device__
__forceinline__
void
ReadDataBc
(
T
*
dst
,
const
T
*
__restrict__
src
,
uint32_t
block_offset
,
details
::
BroadcastConfig
<
ShapeSize
>
config
,
int
total_num_output
,
int
stride_n
x
,
int
stride_n
y
)
{
details
::
BroadcastConfig
<
Rank
>
config
,
int
total_num_output
,
int
stride_nx
,
int
stride_ny
)
{
uint32_t
thread_offset
=
block_offset
+
threadIdx
.
x
*
NX
;
uint32_t
index_src
=
0
;
...
...
@@ -316,7 +323,7 @@ __device__ __forceinline__ void ReadDataBc(
}
}
#pragma unroll
for
(
int
i
=
0
;
i
<
ShapeSize
;
++
i
)
{
for
(
int
i
=
0
;
i
<
Rank
;
++
i
)
{
auto
fast_divmoder
=
config
.
divmoders
[
i
].
Divmod
(
index_output
);
index_output
=
fast_divmoder
.
val
[
0
];
index_src
+=
fast_divmoder
.
val
[
1
]
*
config
.
strides
[
i
];
...
...
@@ -327,27 +334,41 @@ __device__ __forceinline__ void ReadDataBc(
}
/**
* @brief: read data for broadcast
* @typename:
* T : the data type of src
* NX: the cols of src, dst
* NY: in this function NY only can be 1
* BlockSize: the config of this device
* ShapeSize: the shape size of out. eg in[1, 35], out[32, 35] then shape size
* is 2
* IndexCal: get the global index in src, attention config was declared in host;
* IsBoundary: whether to make boundary judgment
* @brief Read 2D data from global memory to registers for reduce.
*
* @template paraments
* T: The type of data stored in the global memory.
* NX: The number of data columns loaded by each thread.
* NY: The number of data rows loaded by each thread.
* BlockSize: Identifies the current device thread index method. For GPU,
* threadIdx.x is used as the thread index, and for xpu, core_id() is used as
* the index. Currently only GPU was supported.
* Rank: The shape size of out. eg in[1, 35], out[32, 35] then shape size is 2.
* IsBoundary: Indicates whether to perform block access storage out-of-bounds
* judgment. When the number of data processed by the block is less than
* NX x NY x blockDim, boundary judgment is required to avoid memory access
* crossing the boundary.
*
* @param:
* dst: The register pointer of the thread, the size is NX * NY.
* src: Raw input data pointer of kernel.
* block_offset: Data offset of this block, blockDim.x * blockIdx.x * NX;
* index_cal: Calculation configuration of Reduce. It is used to calculate the
* coordinate mapping relationship between output data and input data. Please
* refer to the sample code for specific usage.
* block_offset: data offset of this block, blockDim.x * blockIdx.x * NX;
* index_cal: get the global index in src, attention config was declared in
* host;
* size_nx: number of columns to be processed by the current block
* size_ny: number of rows to be processed by the current block
* stride_nx: the stride of cols
* stride_ny: the stride of rows
* reduce_last_dim: according to the block split set threadIdx
* size_nx: The current block needs to load size_nx columns of data, this
* parameter will be used when IsBoundary = true.
* size_ny: The current block needs to load size_ny rows of data. This parameter
* will be used when IsBoundary = true.
* stride_nx: The stride of cols.
* stride_ny: The stride of rows.
* reduce_last_dim: Used to indicate whether the dimension of reduce contains
* the lowest dimension.
*/
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
int
ShapeSize
,
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
int
Rank
,
typename
IndexCal
,
bool
IsBoundary
=
false
>
__device__
__forceinline__
void
ReadDataReduce
(
T
*
dst
,
const
T
*
__restrict__
src
,
int
block_offset
,
...
...
@@ -397,17 +418,26 @@ __device__ __forceinline__ void ReadDataReduce(
}
/**
* @brief store data from src to dst, src can be 1D data, you should set NY = 1.
* When boundary judgment is required, you need to set a to true, and a is false
* by default.
* @typename:
* T : the data type of src
* NX: the cols of src, dst
* NY: in this function NY only can be 1
* BlockSize: the config of this device
* IsBoundary: whether to make boundary judgment
* @brief Write 2D data from registers to global memory. When IsBoundary = true
* and (NX % 4 == 0 or Nx % 2 == 0), the data will be vectorized to improve the
* data loading efficiency
*
* @template paraments
* T: The type of data.
* NX: The number of data continuously loaded by each thread.
* NY: The number of data rows loaded by each thread, only NY = 1 was supported.
* BlockSize: Identifies the current device thread index method. For GPU,
* threadIdx.x is used as the thread index, and for xpu, core_id() is used as
* the index. Currently only GPU was supported.
* IsBoundary: Indicates whether to perform block access storage out-of-bounds
* judgment. When the number of data processed by the block is less than
* NX x NY x blockDim, boundary judgment is required to avoid memory access
* crossing the boundary.
*
* @param:
* num: number of columns to be processed by the current block
* dst: Data pointer of the current block.
* src: The register pointer of the thread, the size is NX * NY.
* size: The current block needs to load size data continuously.
*/
template
<
typename
T
,
int
NX
,
int
NY
,
int
BlockSize
,
bool
IsBoundary
=
false
>
__device__
__forceinline__
void
WriteData
(
T
*
dst
,
T
*
__restrict__
src
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录