Unverified commit 12df57fb, authored Sep 01, 2021 by niuliling123, committed by GitHub on Sep 01, 2021

add ElementwiseTernary, Reduce, ReadDataStride (#35075)

* add ElementwiseTernary, Reduce, ReadDataStride

Parent: d9afa839
Showing 3 changed files with 525 additions and 108 deletions (+525 -108)
paddle/fluid/operators/kernel_primitives/compute_primitives.h    +178  -44
paddle/fluid/operators/kernel_primitives/datamover_primitives.h  +287  -63
paddle/fluid/operators/kernel_primitives/helper_primitives.h     +60   -1
paddle/fluid/operators/kernel_primitives/compute_primitives.h
...
@@ -21,7 +21,8 @@
 #include <hip/hip_fp16.h>
 #endif
-#include <algorithm>
+// #include <algorithm>
+#include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/float16.h"

 namespace paddle {
...
@@ -29,6 +30,16 @@ namespace operators {
 namespace kernel_primitives {
 namespace details {
+
+#ifdef __HIPCC__
+constexpr int kMaxThread = 256;
+constexpr int kWarpSize = 64;
+#else
+constexpr int kMaxThread = 128;
+constexpr int kWarpSize = 32;
+#endif
+
+enum ReduceMode { kGlobalMode, kLocalMode };
+
 template <typename T>
 class MPTypeTrait {
  public:
...
@@ -41,37 +52,98 @@ class MPTypeTrait<platform::float16> {
   using Type = float;
 };

-}  // namespace details
-
-/*************************** Compute Functor****************************/
-template <typename T, typename Enable = void>
-struct DivFunctor {
-  inline HOSTDEVICE T operator()(const T* args) const {
-    return args[0] / args[1];
-  }
-};
-
-template <typename T>
-struct DivFunctor<T, typename std::enable_if_t<std::is_integral<T>::value>> {
-  inline HOSTDEVICE T operator()(const T* args) const {
-    PADDLE_ENFORCE(args[1] != 0,
-                   platform::errors::InvalidArgument(
-                       "Invalid Argument Error: Integer division by zero "
-                       "encountered in divide. Please check the input value."));
-    return args[0] / args[1];
-  }
-};
+
+/**
+ * @brief will be used in BlockYReduce, get the index of reduce_num in shared
+ * memory
+ */
+__device__ __forceinline__ int SharedMemoryIndex(int index) {
+  return (threadIdx.y + index) * blockDim.x + threadIdx.x;
+}
+
+template <typename T, typename ReduceOp>
+__device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
+  unsigned mask = 0u;
+  CREATE_SHFL_MASK(mask, true);
+  for (int stride = details::kWarpSize / 2; stride > 0; stride >>= 1) {
+    T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
+    val = reducer(val, temp);
+  }
+  return val;
+}
+
+/* e.g.
+ * |---------block---------|
+ * |warp0|warp1|warp2|warp3|
+ * |0~31|32~63|64~95|96~127|  ---->blockDim.x = 128
+ *  \|/   \|/   \|/   \|/    ---->1. First WarpReduce in each warp
+ * res0  res1  res2  res3    ---->2. Store result of each warp to shared memory
+ *   \     \    /     /      ---->3. Load the result above from shared memory
+ *        res to warp0 and process the second WarpReduce
+ */
+
+/**
+ * @brief BlockXReduce reduce along blockDim.x
+ */
+template <typename T, typename ReduceOp>
+__device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
+  __syncthreads();
+  using details::kWarpSize;
+  __shared__ T shared[2 * kWarpSize];
+  int block_dim_x = blockDim.x;
+  if (blockDim.x > kWarpSize) {
+    block_dim_x = blockDim.x / kWarpSize;
+    int lane = threadIdx.x % kWarpSize;
+    int tid = threadIdx.y * blockDim.x + threadIdx.x;
+    int wid = tid / kWarpSize;
+    int bid = threadIdx.y;
+    val = WarpReduce(val, reducer);
+    if (lane == 0) {
+      shared[wid] = val;
+    }
+    __syncthreads();
+    val = shared[bid * block_dim_x + lane];
+  }
+
+  unsigned mask = 0u;
+  CREATE_SHFL_MASK(mask, true);
+  for (int stride = 1; stride < block_dim_x; stride <<= 1) {
+    T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
+    val = reducer(val, temp);
+  }
+  return val;
+}
+
+/**
+ * @brief BlockYReduce reduce along blockDim.y
+ */
+template <typename T, typename ReduceOp>
+__device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
+  __shared__ T shared_memory[details::kMaxThread];
+  shared_memory[SharedMemoryIndex(0)] = val;
+  for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) {
+    __syncthreads();
+    if (threadIdx.y < stride && threadIdx.y + stride < blockDim.y) {
+      T temp = shared_memory[SharedMemoryIndex(stride)];
+      val = reducer(val, temp);
+    }
+    shared_memory[SharedMemoryIndex(0)] = val;
+  }
+  return val;
+}
+
+}  // namespace details
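The reduction helpers above are internal (details::) building blocks, and this hunk contains no call site for them. As a rough usage sketch only (not part of the commit; the kernel and functor names are made up, and it assumes Paddle's include paths, a 1D thread block, and CREATE_SHFL_MASK / CudaShuffleDownSync coming in through cuda_device_function.h):

```cpp
#include "paddle/fluid/operators/kernel_primitives/compute_primitives.h"

namespace kps = paddle::operators::kernel_primitives;

// Plain binary reducer passed to the block-level helpers.
template <typename T>
struct SumFunctor {
  __device__ __forceinline__ T operator()(const T& a, const T& b) const {
    return a + b;
  }
};

// Each block writes one partial sum; after BlockXReduce the final
// shuffle-down stage leaves the reduced value in lane 0 of the block.
template <typename T>
__global__ void PartialSumKernel(const T* in, T* out, int n) {
  T val = static_cast<T>(0);
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    val += in[i];
  }
  val = kps::details::BlockXReduce<T, SumFunctor<T>>(val, SumFunctor<T>());
  if (threadIdx.x == 0) {
    out[blockIdx.x] = val;
  }
}
```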
 /*************************** Compute Function****************************/

 /**
- * @brief compute functor for elementwise_two, in1 and in2 has the same shape
+ * @brief binary function, in1 and in2 have same shape
  * @param:
- * T : the type of in1 and in2
- * NX: the row of in1 and in2
- * NY: the col of in1 and in2
- * BlockSize: the strid of col
- * OpFunc: compute functor eg: ADD, SUB, XOR, OR, MUL
+ * T: data type of in1, in2
+ * OutT: data type of out
+ * NX: the cols of in1, in2
+ * NY: the rows of in1, in2
+ * BlockSize: the config of this device
+ * OpFunc: compute functor eg: in1 + in2, in1 - in2
  */
 template <typename T, typename OutT, int NX, int NY, int BlockSize,
           class OpFunc>
...
@@ -88,32 +160,40 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out, const T* in1,
 }

 /**
- * @brief fma eg: a * b + c, in1 in2, in3 and out has the same shape
+ * @brief ternary function, in1, in2 and in3 have same shape
  * @param:
- * T : the type of in1 and in2, in3
- * NX: the row of in1, in2 and in3
- * NY: the col of in1, in2 and in3
- * BlockSize: the strid of col
+ * T: data type of in1, in2, in3
+ * OutT: data type of out
+ * NX: the cols of in1, in2
+ * NY: the rows of in1, in2
+ * BlockSize: the config of this device
+ * OpFunc: compute functor eg: out = in1 * in2 + in3
  */
 template <typename T, typename OutT, int NX, int NY, int BlockSize,
           class OpFunc>
-__device__ __forceinline__ void ElementwiseFma(OutT* out, const T* in1,
+__device__ __forceinline__ void ElementwiseTernary(OutT* out, const T* in1,
                                                const T* in2, const T* in3,
                                                OpFunc compute) {
+  T args[3];
 #pragma unroll
   for (int idx = 0; idx < NX * NY; ++idx) {
-    out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx], in3[idx]));
+    args[0] = in1[idx];
+    args[1] = in2[idx];
+    args[2] = in3[idx];
+    out[idx] = static_cast<OutT>(compute(args));
   }
 }
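The rewritten body packs the three inputs into args before invoking the functor, so an OpFunc for ElementwiseTernary is now written against a const T* rather than three scalars. A minimal sketch of such a functor (the name FmaFunctor is illustrative, not from the commit):

```cpp
// Hypothetical FMA-style functor matching the args-pointer calling convention.
template <typename T>
struct FmaFunctor {
  __device__ __forceinline__ T operator()(const T* args) const {
    return args[0] * args[1] + args[2];  // a * b + c
  }
};

// Typical call inside a kernel, with NX * NY elements staged in registers:
//   kernel_primitives::ElementwiseTernary<T, T, NX, NY, 1, FmaFunctor<T>>(
//       out, in1, in2, in3, FmaFunctor<T>());
```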
 /**
- * @brief compute functor for elementwise_two, in1 is [1, NY], in2 is [NX, NY]
+ * @brief cycle binary function, in1's shape size is [1, NX], in2's shape size
+ * is [NY, NX], out's shape size is [NY, NX]
  * @param:
- * T : the type of in1 and in2
- * NX: the row of in1 and in2
- * NY: the col of in2
- * BlockSize: the strid of col
- * OpFunc: compute functor eg: ADD, SUB, XOR, OR, MUL
+ * T: data type of in1, in2
+ * OutT: data type of out
+ * NX: the cols of in1, in2
+ * NY: the rows of in1, in2
+ * BlockSize: the config of this device
+ * OpFunc: compute functor eg: in1 + in2, in1 - in2
  */
 template <typename T, typename OutT, int NX, int NY, int BlockSize,
           class OpFunc>
...
@@ -130,13 +210,14 @@ __device__ __forceinline__ void CycleBinary(OutT* out, const T* in1,
 }

 /**
- * @brief compute functor for unary, in1 is [NX, NY]
+ * @brief unary function
  * @param:
- * T : the type of in
- * NX: the row of in
- * NY: the col of in
- * BlockSize: the strid of col
- * OpFunc: compute functor eg: relu, sigmoid, exp
+ * T: data type of in
+ * OutT: data type of out
+ * NX: the cols of in
+ * NY: the rows of in
+ * BlockSize: the config of this device
+ * OpFunc: compute functor eg: relu, exp
  */
 template <typename T, typename OutT, int NX, int NY, int BlockSize,
           class OpFunc>
...
@@ -148,6 +229,59 @@ __device__ __forceinline__ void ElementwiseUnary(OutT* out, const T* in,
   }
 }

+/**
+ * @brief reduce function, in's shape size is [NX, NY].
+ * If ReduceMode == kLocalMode then reduce NX, the shape of out is [NY, 1],
+ * if ReduceMode == kGlobalMode then reduce between different threads, the
+ * shape of out is [NY, NX]. If reduce_last_dim is false and reduce_num was
+ * split, BlockYReduce will be called. If reduce_last_dim is true and
+ * reduce_num was split, BlockXReduce will be called
+ * @typename:
+ * T: data type of in
+ * NX: the cols of in
+ * NY: the rows of in
+ * BlockSize: the config of this device
+ * OpFunc: reduce functor, eg: CustomSum, CustomMean in reduce_functor_op.h
+ * @param:
+ * reducer: reduce functor, eg: CustomSum<T>()
+ * reduce_last_dim: if in's last dim need to be reduce then reduce_last_dim =
+ * true
+ */
+template <typename T, int NX, int NY, int BlockSize, class OpFunc,
+          details::ReduceMode Mode>
+__device__ __forceinline__ void Reduce(T* out, const T* in, OpFunc reducer,
+                                       bool reduce_last_dim) {
+  int block_index = blockDim.y;
+
+  if (Mode == details::ReduceMode::kGlobalMode) {
+    bool block_reduce_y = (!reduce_last_dim) && (block_index > 1);
+    // when reduce is not required for the last dim, and reduce num has been
+    // split into multiple threads
+    if (block_reduce_y) {
+#pragma unroll
+      for (int i = 0; i < NY * NX; i++) {  // reduce along blockdim.y
+        out[i] = details::BlockYReduce<T, OpFunc>(out[i], reducer);
+      }
+    }
+
+    // when last dimension need to be reduced
+    if (reduce_last_dim) {
+#pragma unroll
+      for (int i = 0; i < NY * NX; i++) {  // reduce along blockDim.x
+        out[i] = details::BlockXReduce<T, OpFunc>(out[i], reducer);
+      }
+    }
+  } else {  // else kLocalMode
+#pragma unroll
+    for (int i = 0; i < NY; ++i) {
+#pragma unroll
+      for (int j = 0; j < NX; ++j) {
+        out[i] = reducer(out[i], in[i * NX + j]);
+      }
+    }
+  }
+}
+
 }  // namespace kernel_primitives
 }  // namespace operators
 }  // namespace paddle
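A hedged reading of the kLocalMode branch, as a sketch rather than code from the patch: out must already hold the reduction identity, since the local path accumulates in[i * NX + j] into out[i]. Reusing the hypothetical SumFunctor from the earlier sketch, and assuming compute_primitives.h is already included:

```cpp
// Per-thread reduction of an [NY, NX] register tile down to NY values.
template <typename T, int NX, int NY>
__device__ void LocalRowSum(T* out, const T* in) {
  namespace kps = paddle::operators::kernel_primitives;
#pragma unroll
  for (int i = 0; i < NY; ++i) {
    out[i] = static_cast<T>(0);  // identity of the sum reduction
  }
  kps::Reduce<T, NX, NY, 1, SumFunctor<T>,
              kps::details::ReduceMode::kLocalMode>(
      out, in, SumFunctor<T>(), /*reduce_last_dim=*/false);
}
```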
paddle/fluid/operators/kernel_primitives/datamover_primitives.h
...
@@ -13,11 +13,13 @@
 // limitations under the License.

 #pragma once
 #ifdef PADDLE_WITH_CUDA
 #include <cuda.h>
 #include <cuda_fp16.h>
 #include <math.h>
 #include <iostream>
 #include <vector>
 #endif
 #ifdef PADDLE_WITH_HIP
 #include <hip/hip_fp16.h>
 #endif

 namespace paddle {
 namespace operators {
...
@@ -104,52 +106,197 @@ struct BroadcastConfig {
 #undef INT_BITS
 }  // namespace details

-template <typename T, int NX, int NY, int BlockSize>
-__device__ __forceinline__ void ReadDataBase(T* dst, const T* __restrict__ src,
-                                             int size) {
-  int dx = threadIdx.x * NX;
-#pragma unroll
-  for (int idx = 0; idx < NX; ++idx) {
-    if ((idx + dx) >= size) {
-      break;
-    }
-    dst[idx] = src[idx + dx];
-  }
-}
+/**
+ * @brief load data from src to dst, src can be 1D data or 2D data. Note that
+ * you can use this function when you are sure that the data will not cross the
+ * boundary.
+ * @typename:
+ * Tx: data type of src
+ * Ty: data type of dstt
+ * NX: the cols of src, dst
+ * NY: the rows of src, dst
+ * BlockSize: the config of this device
+ * @param:
+ * stride_nx: the stride of cols
+ * stride_ny: the stride of rows
+ */
+template <typename Tx, typename Ty, int NX, int NY, int BlockSize>
+__device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
+                                         int stride_nx, int stride_ny) {
+  if (NY == 1 && NX == 1) {
+    dst[0] = static_cast<Ty>(src[threadIdx.x]);
+  } else if (NX == 1) {
+    int dx = threadIdx.x;
+#pragma unroll
+    for (int idy = 0; idy < NY; ++idy) {
+      dst[idy] = static_cast<Ty>(src[dx + idy * stride_ny]);
+    }
+  } else if (NY == 1) {
+#pragma unroll
+    for (int idx = 0; idx < NX; ++idx) {
+      dst[idx] = static_cast<Ty>(src[idx * stride_nx]);
+    }
+  } else {
+    int dx = threadIdx.x * NX;
+#pragma unroll
+    for (int idx = 0; idx < NX; ++idx) {
+#pragma unroll
+      for (int idy = 0; idy < NY; ++idy) {
+        dst[idy * NX + idx] =
+            static_cast<Ty>(src[idx * stride_nx + dx + idy * stride_ny]);
+      }
+    }
+  }
+}
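For the 2D case the two strides are measured in elements of src: columns of the per-thread NX x NY tile are stride_nx apart and rows are stride_ny apart. A small sketch under that assumption (names are illustrative, not from the diff; it assumes datamover_primitives.h is already included):

```cpp
// Load a thread-local NX x NY tile from a row-major matrix with `cols`
// columns: consecutive columns are contiguous, the next row is `cols` away.
template <typename T, int NX, int NY>
__device__ void LoadTile2D(T* tile, const T* src, int cols) {
  namespace kps = paddle::operators::kernel_primitives;
  kps::ReadData<T, T, NX, NY, 1>(tile, src, /*stride_nx=*/1,
                                 /*stride_ny=*/cols);
}
```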
+/**
+ * @brief load data from src to dst, src can be 1D data or 2D data. When
+ * boundary judgment is required, you need to set a to true, and a is false by
+ * default.
+ * @typename:
+ * Tx: data type of src
+ * Ty: data type of dstt
+ * NX: the cols of src, dst
+ * NY: the rows of src, dst
+ * BlockSize: the config of this device
+ * IsBoundary: whether to make boundary judgment
+ * @param:
+ * size_nx: number of columns to be processed by the current block
+ * size_ny: number of rows to be processed by the current block
+ * stride_nx: the stride of cols
+ * stride_ny: the stride of rows
+ */
+template <typename Tx, typename Ty, int NX, int NY, int BlockSize,
+          bool IsBoundary = false>
+__device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
+                                         int size_nx, int size_ny,
+                                         int stride_nx, int stride_ny) {
+  int dx = threadIdx.x * NX;
+  int size = size_nx - dx;
+
+  // Each branch is added for better performance
+  if (NX == 1 && NY == 1) {  // for NX == 1 and NY == 1
+    if (IsBoundary) {
+      if (dx < size_nx) {
+        dst[0] = static_cast<Ty>(src[dx]);
+      }
+    } else {
+      dst[0] = static_cast<Ty>(src[dx]);
+    }
+  } else if (NX == 1) {  // for NX == 1 and NY != 1
+#pragma unroll
+    for (int idy = 0; idy < NY; ++idy) {
+      if (IsBoundary) {
+        if (idy >= size_ny) {
+          break;
+        }
+      }
+      dst[idy] = static_cast<Ty>(src[dx + idy * stride_ny]);
+    }
+  } else if (NY == 1) {  // for NY == 1 and NX != 1
+#pragma unroll
+    for (int idx = 0; idx < NX; ++idx) {
+      if (IsBoundary) {
+        if (idx >= size) {
+          break;
+        }
+      }
+      dst[idx] = static_cast<Ty>(src[idx * stride_nx + dx]);
+    }
+  } else {  // for NX != 1 and NY != 1
+#pragma unroll
+    for (int idx = 0; idx < NX; ++idx) {
+      if (IsBoundary) {
+        if (idx >= size) {
+          break;
+        }
+      }
+#pragma unroll
+      for (int idy = 0; idy < NY; ++idy) {
+        if (IsBoundary) {
+          if (idy >= size_ny) {
+            break;
+          }
+        }
+        dst[idy * NX + idx] =
+            static_cast<Ty>(src[idx * stride_nx + dx + idy * stride_ny]);
+      }
+    }
+  }
+}
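The boundary-checked overload additionally takes how many elements are actually left for the current block in each direction. A hedged sketch of a tail-tile load (illustrative names, assuming the header is included):

```cpp
// Load a possibly partial NX x NY tile; size_nx/size_ny are the remaining
// elements for this block, strides are as in the non-boundary overload.
template <typename T, int NX, int NY>
__device__ void LoadTailTile(T* tile, const T* src, int size_nx, int size_ny,
                             int cols) {
  namespace kps = paddle::operators::kernel_primitives;
  kps::ReadData<T, T, NX, NY, 1, true>(tile, src, size_nx, size_ny,
                                       /*stride_nx=*/1, /*stride_ny=*/cols);
}
```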
+template <typename T, int NX>
+__device__ __forceinline__ void Init(T* dst, T init_data) {
+#pragma unroll
+  for (int i = 0; i < NX; i++) {
+    dst[i] = init_data;
+  }
+}

-template <typename T, int NX, int NY, int BlockSize>
-__device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
-                                         int size) {
-  const int VECTOR_SIZE = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
-  const int VECTORS_PER_THREAD = NX / VECTOR_SIZE;
-
-  // Vector per thread
-  if (blockDim.x * NX > size) {
-    ReadDataBase<T, NX, NY, BlockSize>(dst, src, size);
-  } else {
-    // Vector type
-    using VecType = details::VectorType<T, VECTOR_SIZE>;
-    VecType vec_temp[VECTORS_PER_THREAD];
-    const VecType* vec_input = reinterpret_cast<const VecType*>(src);
-    ReadDataBase<VecType, VECTORS_PER_THREAD, NY, BlockSize>(
-        vec_temp, vec_input, VECTORS_PER_THREAD * blockDim.x);
-#pragma unroll
-    for (int idx = 0; idx < NX; ++idx) {
-      dst[idx] = *(reinterpret_cast<T*>(vec_temp) + idx);
-    }
-  }
-}
+/** @brief: ReadData
+ * @brief load data from src to dst, src can be 1D data, you should set NY = 1.
+ * When boundary judgment is required, you need to set a to true, and a is false
+ * by default.
+ * @typename:
+ * T : the data type of src
+ * NX: the cols of src, dst
+ * NY: in this function NY only can be 1
+ * BlockSize: the config of this device
+ * IsBoundary: whether to make boundary judgment
+ * @param:
+ * num: number of columns to be processed by the current block
+ */
+template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
+__device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
+                                         int num) {
+  if (IsBoundary) {  // blockDim.x * NX > num
+    int dx = threadIdx.x * NX;
+#pragma unroll
+    for (int idx = 0; idx < NX; ++idx) {
+      if (idx + dx < num) {
+        dst[idx] = src[idx + dx];
+      }
+    }
+  } else {  // blockDim,x * NX < num
+    const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
+    const int kVectorsPerThread = NX / kVectorSize;
+    int tid = threadIdx.x * kVectorsPerThread;
+
+    using VecType = details::VectorType<T, kVectorSize>;
+    const VecType* vec_input = reinterpret_cast<const VecType*>(src);
+    VecType vec_temp[kVectorsPerThread];
+
+#pragma unroll
+    for (int i = 0; i < kVectorsPerThread; ++i) {
+      vec_temp[i] = vec_input[i + tid];
+#pragma unroll
+      for (int idx = 0; idx < NX; ++idx) {
+        dst[idx] = *(reinterpret_cast<T*>(vec_temp) + idx);
+      }
+    }
+  }
+}
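Callers typically resolve IsBoundary at compile time and only pay the per-element checks in the last block. A sketch of that dispatch pattern (not from the patch, names illustrative, header assumed included):

```cpp
template <typename T, int NX>
__device__ void LoadTile1D(T* reg, const T* block_in, int num_left) {
  namespace kps = paddle::operators::kernel_primitives;
  kps::Init<T, NX>(reg, static_cast<T>(0));
  if (num_left >= static_cast<int>(blockDim.x) * NX) {
    // Full tile: vectorized path, no per-element bounds checks.
    kps::ReadData<T, NX, 1, 1, false>(reg, block_in, num_left);
  } else {
    // Tail tile: compile-time IsBoundary = true enables the checked loop.
    kps::ReadData<T, NX, 1, 1, true>(reg, block_in, num_left);
  }
}
```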
-/** @brief: ReadDataBc
- * read data from src ptr when the shape of src and dst are different
+/**
+ * @brief: read data for broadcast
  * @typename:
  * T : the data type of src
  * NX: the cols of src, dst
  * NY: in this function NY only can be 1
  * BlockSize: the config of this device
  * ShapeSize: the shape size of out. eg in[1, 35], out[32, 35] then shape size
  * is 2
+ * IsBoundary: whether to make boundary judgment
  * @param:
- * src: the source pointer
- * dst: the dst pointer
- * stride_nx: the stride of src
- * stride_ny: the stride of src
- * the shape of dst is [NY, NX]
+ * fix: data offset of this block, blockDim.x * blockIdx.x * NX;
+ * config: get the global index in src, attention config was declared in host;
+ * num: the num of out
+ * stride_nx: the stride of cols
+ * stride_ny: the stride of rows
  */
-template <typename T, int NX, int NY, int BlockSize, int ShapeSize>
+template <typename T, int NX, int NY, int BlockSize, int ShapeSize,
+          bool IsBoundary = false>
 __device__ __forceinline__ void ReadDataBc(
     T* dst, const T* __restrict__ src, uint32_t fix,
     details::BroadcastConfig<ShapeSize> config, int num, int stride_nx,
...
@@ -162,7 +309,11 @@ __device__ __forceinline__ void ReadDataBc(
 #pragma unroll
     for (uint32_t nx = 0; nx < NX; ++nx) {
       uint32_t idx = base_offset + ny * stride_ny + nx * stride_nx;
-      if (idx < num) {
+      if (IsBoundary) {
+        if (idx >= num) {
+          break;
+        }
+      }
       offset = 0;
 #pragma unroll
       for (int i = 0; i < ShapeSize; ++i) {
...
@@ -173,42 +324,115 @@ __device__ __forceinline__ void ReadDataBc(
       dst[nx + ny * NX] = src[offset];
-      }
     }
   }
 }
-template <typename T, int NX, int NY, int BlockSize>
-__device__ __forceinline__ void WriteDataBase(T* dst, const T* __restrict__ src,
-                                              int size) {
-  int dx = threadIdx.x * NX;
-#pragma unroll
-  for (int idx = 0; idx < NX; ++idx) {
-    if ((idx + dx) >= size) {
-      break;
-    }
-    dst[idx + dx] = src[idx];
-  }
-}
+/**
+ * @brief: read data for broadcast
+ * @typename:
+ * T : the data type of src
+ * NX: the cols of src, dst
+ * NY: in this function NY only can be 1
+ * BlockSize: the config of this device
+ * ShapeSize: the shape size of out. eg in[1, 35], out[32, 35] then shape size
+ * is 2
+ * IndexCal: get the global index in src, attention config was declared in host;
+ * IsBoundary: whether to make boundary judgment
+ * @param:
+ * fix: data offset of this block, blockDim.x * blockIdx.x * NX;
+ * index_cal: get the global index in src, attention config was declared in
+ * host;
+ * size_nx: number of columns to be processed by the current block
+ * size_ny: number of rows to be processed by the current block
+ * stride_nx: the stride of cols
+ * stride_ny: the stride of rows
+ * reduce_last_dim: according to the block split set threadIdx
+ */
+template <typename T, int NX, int NY, int BlockSize, int ShapeSize,
+          typename IndexCal, bool IsBoundary = false>
+__device__ __forceinline__ void ReadDataReduce(
+    T* dst, const T* __restrict__ src, int fix, const IndexCal& index_cal,
+    int size_nx, int size_ny, int stride_nx, int stride_ny,
+    bool reduce_last_dim) {
+  int base_offset = fix;
+  if (reduce_last_dim) {
+    base_offset += threadIdx.x;
+  } else {
+    base_offset += threadIdx.y;
+  }
+
+  if (NX == 1) {
+#pragma unroll
+    for (int ny = 0; ny < NY; ++ny) {
+      if (IsBoundary) {
+        if (base_offset >= size_ny) {
+          break;
+        }
+      }
+      uint32_t offset = index_cal(base_offset);
+      dst[ny] = src[offset];
+      base_offset += stride_ny;
+    }
+  } else {
+#pragma unroll
+    for (int nx = 0; nx < NX; ++nx) {
+      if (IsBoundary) {
+        if (nx * stride_nx >= size_nx) {
+          break;
+        }
+      }
+#pragma unroll
+      for (int ny = 0; ny < NY; ++ny) {
+        if (IsBoundary) {
+          if (nx * stride_nx >= size_nx) {
+            break;
+          }
+        }
+        uint32_t offset = index_cal(base_offset);
+        dst[nx + ny * NX] = src[offset];
+        base_offset += stride_ny;
+      }
+    }
+  }
+}
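The IndexCal parameter is any callable that maps the logical reduce index to a physical offset in src; the real callers pass host-prepared configs, but the contract can be illustrated with a trivial identity mapping (an assumption for illustration, not from the diff):

```cpp
// Identity index calculator for an already-contiguous layout.
struct IdentityIndexCal {
  __device__ __forceinline__ uint32_t operator()(int logical_index) const {
    return static_cast<uint32_t>(logical_index);
  }
};
```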
-template <typename T, int NX, int NY, int BlockSize>
-__device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src,
-                                          int size) {
-  const int VECTOR_SIZE = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
-  const int VECTORS_PER_THREAD = NX / VECTOR_SIZE;
-
-  // Vector per thread
-  if (blockDim.x * NX > size) {
-    WriteDataBase<T, NX, NY, BlockSize>(dst, src, size);
-  } else {
-    // Vector type
-    using VecType = details::VectorType<T, VECTOR_SIZE>;
-    VecType vec_temp[VECTORS_PER_THREAD];
-#pragma unroll
-    for (int idx = 0; idx < VECTORS_PER_THREAD; ++idx) {
-      vec_temp[idx] = *(reinterpret_cast<VecType*>(src) + idx);
-    }
-    VecType* vec_dst = reinterpret_cast<VecType*>(dst);
-    WriteDataBase<VecType, VECTORS_PER_THREAD, NY, BlockSize>(
-        vec_dst, vec_temp, VECTORS_PER_THREAD * blockDim.x);
-  }
-}
+/** @brief: WriteData
+ * @brief store data from src to dst, src can be 1D data, you should set NY = 1.
+ * When boundary judgment is required, you need to set a to true, and a is false
+ * by default.
+ * @typename:
+ * T : the data type of src
+ * NX: the cols of src, dst
+ * NY: in this function NY only can be 1
+ * BlockSize: the config of this device
+ * IsBoundary: whether to make boundary judgment
+ * @param:
+ * num: number of columns to be processed by the current block
+ */
+template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
+__device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src,
+                                          int num) {
+  if (IsBoundary) {
+    int dx = threadIdx.x * NX;
+#pragma unroll
+    for (int idx = 0; idx < NX; ++idx) {
+      if ((idx + dx) < num) {
+        dst[idx + dx] = src[idx];
+      }
+    }
+  } else {
+    const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
+    const int kVectorsPerThread = NX / kVectorSize;
+    int dx = threadIdx.x * kVectorsPerThread;
+
+    using VecType = details::VectorType<T, kVectorSize>;
+    VecType* vec_dst = reinterpret_cast<VecType*>(dst);
+    VecType vec_temp[kVectorsPerThread];
+#pragma unroll
+    for (int idx = 0; idx < kVectorsPerThread; ++idx) {
+      vec_temp[idx] = *(reinterpret_cast<VecType*>(src) + idx);
+      vec_dst[dx + idx] = vec_temp[idx];
+    }
+  }
+}
...
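Taken together, the data movers support the usual elementwise pattern: stage a register tile, compute on it, and store it back, switching IsBoundary on only for the tail block. A hedged end-to-end sketch (kernel and functor names are made up; it assumes Paddle's include paths):

```cpp
#include "paddle/fluid/operators/kernel_primitives/datamover_primitives.h"

namespace kps = paddle::operators::kernel_primitives;

template <typename T, int NX, bool IsBoundary>
__device__ void ScaleTile(const T* in, T* out, int num) {
  T reg[NX];
  kps::Init<T, NX>(reg, static_cast<T>(0));
  kps::ReadData<T, NX, 1, 1, IsBoundary>(reg, in, num);
#pragma unroll
  for (int i = 0; i < NX; ++i) {
    reg[i] = reg[i] * static_cast<T>(2);  // the "compute" stage
  }
  kps::WriteData<T, NX, 1, 1, IsBoundary>(out, reg, num);
}

template <typename T, int NX>
__global__ void ScaleKernel(const T* in, T* out, int n) {
  int block_offset = blockIdx.x * blockDim.x * NX;
  int num = n - block_offset;  // elements left for this block
  if (num >= static_cast<int>(blockDim.x) * NX) {
    ScaleTile<T, NX, false>(in + block_offset, out + block_offset, num);
  } else if (num > 0) {
    ScaleTile<T, NX, true>(in + block_offset, out + block_offset, num);
  }
}
```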
paddle/fluid/operators/kernel_primitives/helper_primitives.h
...
@@ -16,6 +16,65 @@
 namespace paddle {
 namespace operators {
-namespace kernel_primitives {}
+namespace kernel_primitives {
+namespace details {
+
+static __device__ __forceinline__ platform::float16 ExpFunctor(
+    platform::float16 x) {
+  return ::Eigen::numext::exp(x);
+}
+static __device__ __forceinline__ float ExpFunctor(float x) { return expf(x); }
+static __device__ __forceinline__ double ExpFunctor(double x) { return exp(x); }
+static __device__ __forceinline__ platform::float16 LogFunctor(
+    platform::float16 x) {
+  return ::Eigen::numext::log(x);
+}
+static __device__ __forceinline__ float LogFunctor(float x) { return logf(x); }
+static __device__ __forceinline__ double LogFunctor(double x) { return log(x); }
+
+}  // namespace details
+
+/*************************** Compute Functor****************************/
+// for margin_cross_entropy
+template <typename Tx, typename Ty = Tx>
+struct ExpLogitTransformer {
+  HOSTDEVICE explicit inline ExpLogitTransformer(int n) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx* x) const {
+    return static_cast<Ty>(details::ExpFunctor(x[0]));
+  }
+
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(details::ExpFunctor(x));
+  }
+};
+
+// Post processing function for sum, max, min, prod, any
+template <typename Tx, typename Ty = Tx>
+struct IdentityFunctor {
+  HOSTDEVICE explicit inline IdentityFunctor(int n) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx* x) const {
+    return static_cast<Ty>(x[0]);
+  }
+
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(x);
+  }
+};
+
+// Post processing function for mean
+template <typename T>
+struct DivideFunctor {
+  HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {}
+
+  HOSTDEVICE inline T operator()(const T* x) const { return x[0] * n_inv; }
+
+  HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; }
+
+ private:
+  T n_inv;
+};
+
+}  // namespace kernel_primitives
 }  // namespace operators
 }  // namespace paddle
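These functors are meant as post-processing hooks for the reduce kernels (IdentityFunctor for sum/max/min/prod/any, DivideFunctor for mean). A minimal sketch of the mean finalization, with an illustrative function name and the helper header assumed included:

```cpp
// Turn a partial sum into a mean by multiplying with 1 / reduce_num.
template <typename T>
__device__ void FinalizeMean(T* out, const T* partial_sum, int reduce_num) {
  namespace kps = paddle::operators::kernel_primitives;
  kps::DivideFunctor<T> div(reduce_num);
  out[0] = div(partial_sum[0]);
}
```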